Skip to content

Commit

Permalink
[Refactor] Transforms Utils (#2863)
Browse files Browse the repository at this point in the history
* wip

* tests + docstrings

* improves tests

* fix import
  • Loading branch information
WaelKarkoub authored Jun 6, 2024
1 parent 102d36d commit 8564bd4
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 82 deletions.
115 changes: 33 additions & 82 deletions autogen/agentchat/contrib/capabilities/transforms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import copy
import json
import sys
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union

Expand All @@ -8,8 +7,9 @@

from autogen import token_count_utils
from autogen.cache import AbstractCache, Cache
from autogen.oai.openai_utils import filter_config
from autogen.types import MessageContentType

from . import transforms_util
from .text_compressors import LLMLingua, TextCompressor


Expand Down Expand Up @@ -169,7 +169,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
assert self._min_tokens is not None

# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
if not _min_tokens_reached(messages, self._min_tokens):
if not transforms_util.min_tokens_reached(messages, self._min_tokens):
return messages

temp_messages = copy.deepcopy(messages)
Expand All @@ -178,13 +178,13 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:

for msg in reversed(temp_messages):
# Some messages may not have content.
if not _is_content_right_type(msg.get("content")):
if not transforms_util.is_content_right_type(msg.get("content")):
processed_messages.insert(0, msg)
continue

if not _should_transform_message(msg, self._filter_dict, self._exclude_filter):
if not transforms_util.should_transform_message(msg, self._filter_dict, self._exclude_filter):
processed_messages.insert(0, msg)
processed_messages_tokens += _count_tokens(msg["content"])
processed_messages_tokens += transforms_util.count_text_tokens(msg["content"])
continue

expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message
Expand All @@ -199,7 +199,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
break

msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message)
msg_tokens = _count_tokens(msg["content"])
msg_tokens = transforms_util.count_text_tokens(msg["content"])

# prepend the message to the list to preserve order
processed_messages_tokens += msg_tokens
Expand All @@ -209,10 +209,10 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:

def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
pre_transform_messages_tokens = sum(
_count_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
transforms_util.count_text_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
)
post_transform_messages_tokens = sum(
_count_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
transforms_util.count_text_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
)

if post_transform_messages_tokens < pre_transform_messages_tokens:
Expand Down Expand Up @@ -349,31 +349,32 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
return messages

# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
if not _min_tokens_reached(messages, self._min_tokens):
if not transforms_util.min_tokens_reached(messages, self._min_tokens):
return messages

total_savings = 0
processed_messages = messages.copy()
for message in processed_messages:
# Some messages may not have content.
if not _is_content_right_type(message.get("content")):
if not transforms_util.is_content_right_type(message.get("content")):
continue

if not _should_transform_message(message, self._filter_dict, self._exclude_filter):
if not transforms_util.should_transform_message(message, self._filter_dict, self._exclude_filter):
continue

if _is_content_text_empty(message["content"]):
if transforms_util.is_content_text_empty(message["content"]):
continue

cached_content = self._cache_get(message["content"])
cache_key = transforms_util.cache_key(message["content"], self._min_tokens)
cached_content = transforms_util.cache_content_get(self._cache, cache_key)
if cached_content is not None:
savings, compressed_content = cached_content
message["content"], savings = cached_content
else:
savings, compressed_content = self._compress(message["content"])
message["content"], savings = self._compress(message["content"])

self._cache_set(message["content"], compressed_content, savings)
transforms_util.cache_content_set(self._cache, cache_key, message["content"], savings)

message["content"] = compressed_content
assert isinstance(savings, int)
total_savings += savings

self._recent_tokens_savings = total_savings
Expand All @@ -385,88 +386,38 @@ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages:
else:
return "No tokens saved with text compression.", False

def _compress(self, content: Union[str, List[Dict]]) -> Tuple[int, Union[str, List[Dict]]]:
def _compress(self, content: MessageContentType) -> Tuple[MessageContentType, int]:
"""Compresses the given text or multimodal content using the specified compression method."""
if isinstance(content, str):
return self._compress_text(content)
elif isinstance(content, list):
return self._compress_multimodal(content)
else:
return 0, content
return content, 0

def _compress_multimodal(self, content: List[Dict]) -> Tuple[int, List[Dict]]:
def _compress_multimodal(self, content: MessageContentType) -> Tuple[MessageContentType, int]:
tokens_saved = 0
for msg in content:
if "text" in msg:
savings, msg["text"] = self._compress_text(msg["text"])
for item in content:
if isinstance(item, dict) and "text" in item:
item["text"], savings = self._compress_text(item["text"])
tokens_saved += savings

elif isinstance(item, str):
item, savings = self._compress_text(item)
tokens_saved += savings
return tokens_saved, content

def _compress_text(self, text: str) -> Tuple[int, str]:
return content, tokens_saved

def _compress_text(self, text: str) -> Tuple[str, int]:
"""Compresses the given text using the specified compression method."""
compressed_text = self._text_compressor.compress_text(text, **self._compression_args)

savings = 0
if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text:
savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"]

return savings, compressed_text["compressed_prompt"]

def _cache_get(self, content: Union[str, List[Dict]]) -> Optional[Tuple[int, Union[str, List[Dict]]]]:
if self._cache:
cached_value = self._cache.get(self._cache_key(content))
if cached_value:
return cached_value

def _cache_set(
self, content: Union[str, List[Dict]], compressed_content: Union[str, List[Dict]], tokens_saved: int
):
if self._cache:
value = (tokens_saved, compressed_content)
self._cache.set(self._cache_key(content), value)

def _cache_key(self, content: Union[str, List[Dict]]) -> str:
return f"{json.dumps(content)}_{self._min_tokens}"
return compressed_text["compressed_prompt"], savings

def _validate_min_tokens(self, min_tokens: Optional[int]):
if min_tokens is not None and min_tokens <= 0:
raise ValueError("min_tokens must be greater than 0 or None")


def _min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
"""Returns True if the total number of tokens in the messages is greater than or equal to the specified value."""
if not min_tokens:
return True

messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
return messages_tokens >= min_tokens


def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
token_count = 0
if isinstance(content, str):
token_count = token_count_utils.count_token(content)
elif isinstance(content, list):
for item in content:
token_count += _count_tokens(item.get("text", ""))
return token_count


def _is_content_right_type(content: Any) -> bool:
return isinstance(content, (str, list))


def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
if isinstance(content, str):
return content == ""
elif isinstance(content, list):
return all(_is_content_text_empty(item.get("text", "")) for item in content)
else:
return False


def _should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool:
if not filter_dict:
return True

return len(filter_config([message], filter_dict, exclude)) > 0
114 changes: 114 additions & 0 deletions autogen/agentchat/contrib/capabilities/transforms_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from typing import Any, Dict, Hashable, List, Optional, Tuple

from autogen import token_count_utils
from autogen.cache.abstract_cache_base import AbstractCache
from autogen.oai.openai_utils import filter_config
from autogen.types import MessageContentType


def cache_key(content: MessageContentType, *args: Hashable) -> str:
"""Calculates the cache key for the given message content and any other hashable args.
Args:
content (MessageContentType): The message content to calculate the cache key for.
*args: Any additional hashable args to include in the cache key.
"""
str_keys = [str(key) for key in (content, *args)]
return "".join(str_keys)


def cache_content_get(cache: Optional[AbstractCache], key: str) -> Optional[Tuple[MessageContentType, ...]]:
"""Retrieves cachedd content from the cache.
Args:
cache (None or AbstractCache): The cache to retrieve the content from. If None, the cache is ignored.
key (str): The key to retrieve the content from.
"""
if cache:
cached_value = cache.get(key)
if cached_value:
return cached_value


def cache_content_set(cache: Optional[AbstractCache], key: str, content: MessageContentType, *extra_values):
"""Sets content into the cache.
Args:
cache (None or AbstractCache): The cache to set the content into. If None, the cache is ignored.
key (str): The key to set the content into.
content (MessageContentType): The message content to set into the cache.
*extra_values: Additional values to be passed to the cache.
"""
if cache:
cache_value = (content, *extra_values)
cache.set(key, cache_value)


def min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
"""Returns True if the total number of tokens in the messages is greater than or equal to the specified value.
Args:
messages (List[Dict]): A list of messages to check.
"""
if not min_tokens:
return True

messages_tokens = sum(count_text_tokens(msg["content"]) for msg in messages if "content" in msg)
return messages_tokens >= min_tokens


def count_text_tokens(content: MessageContentType) -> int:
"""Calculates the number of text tokens in the given message content.
Args:
content (MessageContentType): The message content to calculate the number of text tokens for.
"""
token_count = 0
if isinstance(content, str):
token_count = token_count_utils.count_token(content)
elif isinstance(content, list):
for item in content:
if isinstance(item, str):
token_count += token_count_utils.count_token(item)
else:
token_count += count_text_tokens(item.get("text", ""))
return token_count


def is_content_right_type(content: Any) -> bool:
"""A helper function to check if the passed in content is of the right type."""
return isinstance(content, (str, list))


def is_content_text_empty(content: MessageContentType) -> bool:
"""Checks if the content of the message does not contain any text.
Args:
content (MessageContentType): The message content to check.
"""
if isinstance(content, str):
return content == ""
elif isinstance(content, list):
texts = []
for item in content:
if isinstance(item, str):
texts.append(item)
elif isinstance(item, dict):
texts.append(item.get("text", ""))
return not any(texts)
else:
return True


def should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool:
"""Validates whether the transform should be applied according to the filter dictionary.
Args:
message (Dict[str, Any]): The message to validate.
filter_dict (None or Dict[str, Any]): The filter dictionary to validate against. If None, the transform is always applied.
exclude (bool): Whether to exclude messages that match the filter dictionary.
"""
if not filter_dict:
return True

return len(filter_config([message], filter_dict, exclude)) > 0
2 changes: 2 additions & 0 deletions autogen/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Dict, List, Literal, TypedDict, Union

MessageContentType = Union[str, List[Union[Dict, str]], None]


class UserMessageTextContentPart(TypedDict):
type: Literal["text"]
Expand Down
Loading

0 comments on commit 8564bd4

Please sign in to comment.