Skip to content

[FEAT] summarized chat completion context #6217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
364ddf6
FEAT: build Just draft
SongChiYoung Apr 5, 2025
29f29de
FEAT: End of design summarized cat completion
SongChiYoung Apr 5, 2025
03353ef
FEAT: add conditions
SongChiYoung Apr 6, 2025
fda9bed
Merge remote-tracking branch 'upstream/main' into feature/_summarized…
SongChiYoung Apr 6, 2025
2a651a5
Merge remote-tracking branch 'upstream/main' into feature/_summarized…
SongChiYoung Apr 6, 2025
e88443b
FEAT: Add TokenUsageMessageCompletion
SongChiYoung Apr 6, 2025
33e9595
Merge remote-tracking branch 'upstream/main' into feature/_summarized…
SongChiYoung Apr 7, 2025
ec63df3
feat: maybe, add all of conditions
SongChiYoung Apr 7, 2025
9e4fef9
FIX: fix errors
SongChiYoung Apr 7, 2025
f772ed4
FEAT: buffered_summary add for included summary func
SongChiYoung Apr 7, 2025
d063f15
Merge remote-tracking branch 'upstream/main' into feature/_summarized…
SongChiYoung Apr 7, 2025
b337c13
Merge remote-tracking branch 'upstream/main' into feature/_summarized…
SongChiYoung Apr 10, 2025
df60122
FEAT: add bufferd_summary_context
SongChiYoung Apr 10, 2025
37ef9bf
FEAT: serialize aware
SongChiYoung Apr 10, 2025
c536f8f
FEAT: pyright and mypy aware
SongChiYoung Apr 10, 2025
88894d3
FIX: Docstring fixed
SongChiYoung Apr 10, 2025
d8d18a8
TEST: add summary model context test
SongChiYoung Apr 10, 2025
90486e0
CHOR: test code pyright/mypy
SongChiYoung Apr 10, 2025
b8ddf8e
FEAT: add summary engines test cases
SongChiYoung Apr 10, 2025
5686a69
FEAT: test cov done
SongChiYoung Apr 10, 2025
daf0ee8
Merge remote-tracking branch 'upstream/main' into feature/_summarized…
SongChiYoung Apr 11, 2025
5ac3fb9
Merge branch 'main' into feature/_summarized_chat_completion_context
SongChiYoung Apr 12, 2025
4e3c53d
FEAT: LLM summary engine is now appear
SongChiYoung Apr 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from ._buffered_chat_completion_context import BufferedChatCompletionContext
from ._chat_completion_context import ChatCompletionContext, ChatCompletionContextState
from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
from ._summarized_chat_completion_context import (
SummarizedChatCompletionContext,
)
from ._token_limited_chat_completion_context import TokenLimitedChatCompletionContext
from ._unbounded_chat_completion_context import (
UnboundedChatCompletionContext,
)
from .conditions import (
ContextMessage,
SummarizngFunction,
TriggerMessage,
)

__all__ = [
"ChatCompletionContext",
Expand All @@ -13,4 +21,8 @@
"BufferedChatCompletionContext",
"TokenLimitedChatCompletionContext",
"HeadAndTailChatCompletionContext",
"SummarizedChatCompletionContext",
"ContextMessage",
"TriggerMessage",
"SummarizngFunction",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from typing import List

from pydantic import BaseModel
from typing_extensions import Self

from autogen_core import ComponentModel

from .._component_config import Component
from ..models import LLMMessage
from ..tools._base import BaseTool
from ._chat_completion_context import ChatCompletionContext
from .conditions import MessageCompletionCondition, SummarizngFunction, SummaryFunction


class SummarizedChatCompletionContextConfig(BaseModel):
summarizing_func: ComponentModel
summarizing_condition: ComponentModel
initial_messages: List[LLMMessage] | None = None
non_summarized_messages: List[LLMMessage] | None = None


class SummarizedChatCompletionContext(ChatCompletionContext, Component[SummarizedChatCompletionContextConfig]):
"""A summarized chat completion context that summarizes the messages in the context
using a summarizing function. The summarizing function is set at initialization.
The summarizing condition is used to determine when to summarize the messages.

Args:
summarizing_func (Callable[[List[LLMMessage]], List[LLMMessage]]): The function to summarize the messages.
summarizing_condition (MessageCompletionCondition): The condition to determine when to summarize the messages.
initial_messages (List[LLMMessage] | None): The initial messages.

Example:
.. code-block:: python

from typing import List

from autogen_core.model_context import SummarizedChatCompletionContext
from autogen_core.models import LLMMessage


def summarizing_func(messages: List[LLMMessage]) -> List[LLMMessage]:
# Implement your summarizing function here.
return messages


summarizing_condition = MessageCompletionCondition()

context = SummarizedChatCompletionContext(
summarizing_func=summarizing_func,
summarizing_condition=summarizing_condition,
)

.. code-block:: python
import asyncio
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_core.model_context import SummarizedChatCompletionContext
from autogen_core.model_context.conditions import MaxMessageCompletion
from autogen_ext.summary import buffered_summary, buffered_summarized_chat_completion_context


client = OpenAIChatCompletionClient(model="claude-3-haiku-20240307")

context = SummarizedChatCompletionContext(
summarizing_func=buffered_summary(buffer_count=2), summarizing_condition=MaxMessageCompletion(max_messages=2)
)

agent = AssistantAgent(
"helper", model_client=client, system_message="You are a helpful agent", model_context=context
)

"""

component_config_schema = SummarizedChatCompletionContextConfig
component_provider_override = "autogen_core.model_context.SummarizedChatCompletionContext"

def __init__(
self,
summarizing_func: SummaryFunction | SummarizngFunction,
summarizing_condition: MessageCompletionCondition,
initial_messages: List[LLMMessage] | None = None,
non_summarized_messages: List[LLMMessage] | None = None,
) -> None:
super().__init__(initial_messages)

self._non_summarized_messages: List[LLMMessage] = []
if non_summarized_messages is not None:
self._non_summarized_messages.extend(non_summarized_messages)

self._non_summarized_messages.extend(self._messages)

self._summarizing_func: SummaryFunction
if isinstance(summarizing_func, BaseTool):
# If the summarizing function is a tool, use it directly.
self._summarizing_func = summarizing_func
elif callable(summarizing_func):
self._summarizing_func = SummaryFunction(summarizing_func)
else:
raise ValueError("summarizing_func must be a callable or a tool.")
self._summarizing_condition = summarizing_condition

async def add_message(self, message: LLMMessage) -> None:
"""Add a message to the context."""
self._non_summarized_messages.append(message)
await super().add_message(message)

# Check if the summarizing condition is met.
await self._summarizing_condition(self._messages)
if self._summarizing_condition.triggered:
# If the condition is met, summarize the messages.
await self.summary()
await self._summarizing_condition.reset()

async def get_messages(self) -> List[LLMMessage]:
return self._messages

async def summary(self) -> None:
"""Summarize the messages in the context using the summarizing function."""
summarized_message = self._summarizing_func.run(self._messages, self._non_summarized_messages)
self._messages = summarized_message

def _to_config(self) -> SummarizedChatCompletionContextConfig:
return SummarizedChatCompletionContextConfig(
summarizing_func=self._summarizing_func.dump_component(),
summarizing_condition=self._summarizing_condition.dump_component(),
initial_messages=self._initial_messages,
)

@classmethod
def _from_config(cls, config: SummarizedChatCompletionContextConfig) -> Self:
return cls(
summarizing_func=SummaryFunction.load_component(config.summarizing_func),
summarizing_condition=MessageCompletionCondition.load_component(config.summarizing_condition),
initial_messages=config.initial_messages,
non_summarized_messages=config.non_summarized_messages,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from ._base_condition import (
AndMessageCompletionCondition,
MessageCompletionCondition,
MessageCompletionException,
OrMessageCompletionCondition,
)
from ._contidions import (
ExternalMessageCompletion,
FunctionCallMessageCompletion,
MaxMessageCompletion,
SourceMatchMessageCompletion,
StopMessageCompletion,
TextMentionMessageCompletion,
TextMessageMessageCompletion,
TimeoutMessageCompletion,
TokenUsageMessageCompletion,
)
from ._summary_function import (
SummaryFunction,
)
from ._types import (
ContextMessage,
SummarizngFunction,
TriggerMessage,
)

__all__ = [
"MessageCompletionCondition",
"AndMessageCompletionCondition",
"OrMessageCompletionCondition",
"MessageCompletionException",
"ContextMessage",
"TriggerMessage",
"SummarizngFunction",
"StopMessageCompletion",
"MaxMessageCompletion",
"TextMentionMessageCompletion",
"TokenUsageMessageCompletion",
"TimeoutMessageCompletion",
"ExternalMessageCompletion",
"SourceMatchMessageCompletion",
"TextMessageMessageCompletion",
"FunctionCallMessageCompletion",
"SummaryFunction",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import asyncio
from abc import ABC, abstractmethod
from typing import List, Sequence

from pydantic import BaseModel
from typing_extensions import Self

from autogen_core import Component, ComponentBase, ComponentModel

from ._types import ContextMessage, TriggerMessage


class MessageCompletionException(BaseException): ...


class MessageCompletionCondition(ABC, ComponentBase[BaseModel]):
"""A stateful condition that determines when a message completion should be triggered.
A message completion condition is a callable that takes a sequence of ContextMessage objects
since the last time the condition was called, and returns a TriggerMessage if the
conversation should be terminated, or None otherwise.
Once a message completion condition has been reached, it must be reset before it can be used again.
Message completion conditions can be combined using the AND and OR operators.
Example:
.. code-block:: python
import asyncio
from autogen_core.model_context.conditions import (
MaxMessageCompletion,
TextMentionMessageCompletion,
)


async def main() -> None:
# Terminate the conversation after 10 turns or if the text "SUMMARY" is mentioned.
cond1 = MaxMessageCompletion(10) | TextMentionMessageCompletion("SUMMARY")
# Terminate the conversation after 10 turns and if the text "SUMMARY" is mentioned.
cond2 = MaxMessageCompletion(10) & TextMentionMessageCompletion("SUMMARY")
# ...
# Reset the message completion condition.
await cond1.reset()
await cond2.reset()


asyncio.run(main())
"""

component_type = "message_completion_condition"

@property
@abstractmethod
def triggered(self) -> bool:
"""Check if the trigger condition has been reached"""
...

@abstractmethod
async def __call__(self, messages: Sequence[ContextMessage]) -> TriggerMessage | None:
"""Check if the message completion should be triggered based on the messages received since the last call.

Args:
messages (Sequence[ContextMessage]): The messages received since the last call.
Returns:
TriggerMessage | None: The trigger message if the condition is met, or None if not.
Raises:
MessageCompletionException: If the message completion condition has already been reached."""
...

@abstractmethod
async def reset(self) -> None:
"""Reset the model completion condition."""
...

def __and__(self, other: "MessageCompletionCondition") -> "MessageCompletionCondition":
"""Combine two trigger conditions with an AND operation."""
return AndMessageCompletionCondition(self, other)

def __or__(self, other: "MessageCompletionCondition") -> "MessageCompletionCondition":
"""Combine two trigger conditions with an OR operation."""
return OrMessageCompletionCondition(self, other)


class AndMessageCompletionConditionConfig(BaseModel):
conditions: List[ComponentModel]


class AndMessageCompletionCondition(MessageCompletionCondition, Component[AndMessageCompletionConditionConfig]):
component_config_schema = AndMessageCompletionConditionConfig
component_type = "trigger"
component_provider_override = "autogen_core.model_context.conditions.AndMessageCompletionCondition"

def __init__(self, *conditions: MessageCompletionCondition) -> None:
self._conditions = conditions
self._trigger_messages: List[TriggerMessage] = []

@property
def triggered(self) -> bool:
return all(condition.triggered for condition in self._conditions)

async def __call__(self, messages: Sequence[ContextMessage]) -> TriggerMessage | None:
if self.triggered:
raise MessageCompletionException("Message completion condition has already been reached.")
# Check all remaining conditions.
trigger_messages = await asyncio.gather(
*[condition(messages) for condition in self._conditions if not condition.triggered]
)
# Collect stop messages.
for trigger_message in trigger_messages:
if trigger_message is not None:
self._trigger_messages.append(trigger_message)
if any(trigger_message is None for trigger_message in trigger_messages):
# If any remaining condition has not reached termination, it is not terminated.
return None
content = ", ".join(trigger_message.content for trigger_message in self._trigger_messages)
source = ", ".join(trigger_message.source for trigger_message in self._trigger_messages)
return TriggerMessage(content=content, source=source)

async def reset(self) -> None:
for condition in self._conditions:
await condition.reset()
self._trigger_messages.clear()

def _to_config(self) -> AndMessageCompletionConditionConfig:
"""Convert the AND trigger condition to a config."""
return AndMessageCompletionConditionConfig(
conditions=[condition.dump_component() for condition in self._conditions]
)

@classmethod
def _from_config(cls, config: AndMessageCompletionConditionConfig) -> Self:
"""Create an AND trigger condition from a config."""
conditions = [
MessageCompletionCondition.load_component(condition_model) for condition_model in config.conditions
]
return cls(*conditions)


class OrMessageCompletionConditionConfig(BaseModel):
conditions: List[ComponentModel]
"""List of termination conditions where any one being satisfied is sufficient."""


class OrMessageCompletionCondition(MessageCompletionCondition, Component[OrMessageCompletionConditionConfig]):
component_config_schema = OrMessageCompletionConditionConfig
component_type = "trigger"
component_provider_override = "autogen_core.model_context.conditions.OrTerminationCondition"

def __init__(self, *conditions: MessageCompletionCondition) -> None:
self._conditions = conditions

@property
def triggered(self) -> bool:
return any(condition.triggered for condition in self._conditions)

async def __call__(self, messages: Sequence[ContextMessage]) -> TriggerMessage | None:
if self.triggered:
raise RuntimeError("Message completion condition has already been reached")
trigger_messages = await asyncio.gather(*[condition(messages) for condition in self._conditions])
trigger_messages_filter = [
trigger_message for trigger_message in trigger_messages if trigger_message is not None
]
if len(trigger_messages_filter) > 0:
content = ", ".join(trigger_message.content for trigger_message in trigger_messages_filter)
source = ", ".join(trigger_message.source for trigger_message in trigger_messages_filter)
return TriggerMessage(content=content, source=source)
return None

async def reset(self) -> None:
for condition in self._conditions:
await condition.reset()

def _to_config(self) -> OrMessageCompletionConditionConfig:
"""Convert the OR trigger condition to a config."""
return OrMessageCompletionConditionConfig(
conditions=[condition.dump_component() for condition in self._conditions]
)

@classmethod
def _from_config(cls, config: OrMessageCompletionConditionConfig) -> Self:
"""Create an OR trigger condition from a config."""
conditions = [
MessageCompletionCondition.load_component(condition_model) for condition_model in config.conditions
]
return cls(*conditions)
Loading