Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Experimental] Modality Transforms #2836

Closed
wants to merge 15 commits into from
5 changes: 4 additions & 1 deletion autogen/agentchat/contrib/capabilities/generate_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def __init__(
text_analyzer_instructions: str = PROMPT_INSTRUCTIONS,
verbosity: int = 0,
register_reply_position: int = 2,
output_prompt_template: str = "I generated an image with the prompt: {prompt}",
):
"""
Args:
Expand All @@ -165,13 +166,15 @@ def __init__(
register_reply_position (int): The position of the reply function in the agent's list of reply functions.
This capability registers a new reply function to handle messages with image generation requests.
Defaults to 2 to place it after the check termination and human reply for a ConversableAgent.
output_prompt_template (str): The template for the output prompt.
"""
self._image_generator = image_generator
self._cache = cache
self._text_analyzer_llm_config = text_analyzer_llm_config
self._text_analyzer_instructions = text_analyzer_instructions
self._verbosity = verbosity
self._register_reply_position = register_reply_position
self._output_prompt_template = output_prompt_template

self._agent: Optional[ConversableAgent] = None
self._text_analyzer: Optional[TextAnalyzerAgent] = None
Expand Down Expand Up @@ -271,7 +274,7 @@ def _extract_analysis(self, analysis: Union[str, Dict, None]) -> str:
def _generate_content_message(self, prompt: str, image: Image) -> Dict[str, Any]:
    """Build a multimodal message pairing the templated prompt text with the generated image.

    The image is embedded as a data-URI ``image_url`` item so downstream
    multimodal clients can render it directly.
    """
    text_item = {"type": "text", "text": self._output_prompt_template.format(prompt=prompt)}
    image_item = {"type": "image_url", "image_url": {"url": img_utils.pil_to_data_uri(image)}}
    return {"content": [text_item, image_item]}
Expand Down
23 changes: 23 additions & 0 deletions autogen/agentchat/contrib/capabilities/image_captioners.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Dict, List, Literal, Optional, Protocol, Union

from transformers import pipeline


class ImageCaptioner(Protocol):
    """Structural (duck-typed) interface for anything that can caption an image.

    Implementations take an image URL/URI string and return a text caption.
    """

    def caption_image(self, image_url: str) -> str: ...


class HuggingFaceImageCaptioner:
    """ImageCaptioner backed by a Hugging Face ``image-to-text`` pipeline."""

    def __init__(
        self,
        model: str = "Salesforce/blip-image-captioning-base",
    ):
        # NOTE(review): constructing the pipeline may download model weights on
        # first use — confirm this is acceptable at capability-creation time.
        self._captioner = pipeline("image-to-text", model=model)

    def caption_image(self, image_url: str) -> str:
        """Return a caption for the image at ``image_url``.

        Returns "" when the pipeline output is not the expected
        ``[{"generated_text": ...}]`` shape.
        """
        result = self._captioner(image_url)
        if isinstance(result, list) and len(result) > 0 and isinstance(result[0], dict):
            return result[0].get("generated_text", "")
        return ""
239 changes: 239 additions & 0 deletions autogen/agentchat/contrib/capabilities/modality_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
from typing import Dict, List, Literal, Optional, Sequence, Set, Tuple, Union

from autogen.agentchat.contrib import img_utils
from autogen.agentchat.utils import parse_tags_from_content, replace_tag_in_content
from autogen.cache.cache import AbstractCache, Cache
from autogen.types import MessageContentType

from .image_captioners import ImageCaptioner

# Coarse modality categories a message content item can belong to.
ModalitiesType = Literal["text", "image", "video", "audio"]
# Maps each coarse modality to the concrete "type" strings that appear in
# multimodal message items (e.g. images arrive as both "image" and "image_url").
MODALITIES_ALIAS: Dict[ModalitiesType, List[str]] = {
    "text": ["text"],
    "image": ["image", "image_url"],
    "video": ["video"],
    "audio": ["audio"],
}


class ImageModality:
    """Adapts image content in messages to what the receiving agent supports.

    If the agent has the image modality, ``<img ...>`` tags found in text are
    converted into multimodal (text + image_url) content items. Otherwise,
    images and image tags are replaced with text captions produced by an
    :class:`ImageCaptioner`, and remaining image items can be dropped.
    """

    def __init__(
        self,
        image_captioner: Optional[ImageCaptioner] = None,
        caption_template: str = "(You received an image and here is the caption: {caption}.)",
        agent_has_image_modality: bool = False,
        drop_unsupported_message_format: bool = True,
        cache: Optional[AbstractCache] = None,
    ):
        """
        Args:
            image_captioner: captioner used to turn images into text. Required
                when the agent does not have the image modality.
            caption_template: template for caption text; may reference
                ``{caption}``, and optionally ``{idx}`` (per-image number) and
                ``{total}`` (total number of images).
            agent_has_image_modality: whether the agent can consume image items.
            drop_unsupported_message_format: drop image items the agent cannot
                handle. Incompatible with ``agent_has_image_modality=True``.
            cache: cache to use. Defaults to a disk cache created per instance.
                (Previously ``Cache.disk()`` was a default argument, created at
                import time and shared by every instance.)

        Raises:
            ValueError: on invalid argument combinations.
        """
        self._validate_modality_support(agent_has_image_modality, drop_unsupported_message_format, image_captioner)

        self._captioner = image_captioner
        self._caption_template = caption_template
        self._agent_has_image_modality = agent_has_image_modality

        drop_unsupported_transform = _drop_unsupported_factory(
            unsupported_agent_modalities=["image"] if not agent_has_image_modality else None,
            modalities_alias=MODALITIES_ALIAS,
        )

        self._drop_unsupported = drop_unsupported_transform if drop_unsupported_message_format else None
        # Avoid a shared mutable default: build the disk cache lazily per instance.
        self._cache = cache if cache is not None else Cache.disk()

        # Per-apply_transform counters consumed by get_logs.
        self._n_tags_converted = 0
        self._n_tags_captioned = 0
        self._n_images_captioned = 0

    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
        """Convert or caption image content in each message and return the list.

        Messages are modified in place; malformed or empty-content messages are
        skipped rather than aborting the whole transform (the previous early
        ``return messages`` stopped processing at the first such message and
        skipped the drop step entirely).
        """
        self._n_tags_converted = 0
        self._n_tags_captioned = 0
        self._n_images_captioned = 0

        for message in messages:
            if not message.get("content"):
                continue
            if not isinstance(message["content"], (list, str)):
                continue

            if self._agent_has_image_modality:
                message["content"] = self._convert_tags_to_multimodal_content(message["content"])
            else:
                assert self._captioner, "Must provide an image captioner to convert images to text."
                message["content"] = self._replace_images_with_captions(message["content"])
                message["content"] = self._replace_tags_with_captions(message["content"])

        if self._drop_unsupported:
            # We want to only drop the image types the agent cannot consume.
            messages = self._drop_unsupported.apply_transform(messages)

        return messages

    def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
        """Return (log message, whether anything changed) for the last apply_transform call."""
        logs = []
        if self._n_tags_converted > 0:
            logs.append(f"Converted {self._n_tags_converted} image tags to multimodal content.")

        if self._n_tags_captioned > 0:
            logs.append(f"Captioned {self._n_tags_captioned} image tags to text.")

        if self._n_images_captioned > 0:
            logs.append(f"Captioned {self._n_images_captioned} images to text.")

        if len(logs) > 0:
            return "\n".join(logs), True
        return "No images were found.", False

    def _convert_tags_to_multimodal_content(self, content: MessageContentType) -> List[Union[Dict, str]]:
        """Expand ``<img ...>`` tags in text into image_url content items."""
        initial_image_count = _count_content_type(content, "image_url")

        if isinstance(content, str):
            modified_content = img_utils.gpt4v_formatter(content)
            current_image_count = _count_content_type(modified_content, "image_url")
            self._n_tags_converted += current_image_count - initial_image_count
            # Reuse the already-formatted result (previously formatted twice).
            return modified_content

        modified_content = []
        if isinstance(content, list):
            for item in content:
                if isinstance(item, str):
                    modified_content.extend(img_utils.gpt4v_formatter(item))
                elif "text" in item:
                    modified_content.extend(img_utils.gpt4v_formatter(item["text"]))
                else:
                    modified_content.append(item)

        current_image_count = _count_content_type(modified_content, "image_url")
        self._n_tags_converted += current_image_count - initial_image_count
        return modified_content

    def _replace_tags_with_captions(self, content: MessageContentType) -> Union[List[Union[Dict, str]], str]:
        """Replace every ``<img ...>`` tag in the content with a text caption."""
        assert self._captioner
        for tag in parse_tags_from_content("img", content):
            try:
                caption = self._captioner.caption_image(tag["attr"]["src"])
                replacement_text = self._caption_template.format(caption=caption)
                self._n_tags_captioned += 1
            except Exception:
                # Best-effort: leave an explanatory note instead of failing the transform.
                replacement_text = (
                    "(You failed to convert the image tag to text. "
                    f"Possibly due to invalid image source {tag['attr']['src']}.)"
                )
            content = replace_tag_in_content(tag, content, replacement_text)
        return content

    def _replace_images_with_captions(self, content: MessageContentType) -> List[Union[Dict, str]]:
        """Caption image_url items and fold the captions into a text item.

        Captions are appended to the last text item, or inserted as a new
        leading text item when the content has none.
        """
        assert self._captioner

        if isinstance(content, str):
            return [content]

        # A list headed by a plain string carries no typed image items.
        if isinstance(content, list) and len(content) > 0 and isinstance(content[0], str):
            return content

        output_captions = ""
        img_number = 1
        txt_idx = None
        for idx, item in enumerate(content):
            if not isinstance(item, dict):
                continue

            if item.get("type") == "text":
                txt_idx = idx  # remember the last text item as the caption target

            if item.get("type") == "image_url":
                img = item["image_url"]["url"]
                try:
                    caption = self._captioner.caption_image(img)
                    output_captions += self._caption_template.format(idx=img_number, caption=caption) + "\n"
                    self._n_images_captioned += 1
                except Exception:
                    output_captions += f"(Failed to generate caption for image {img_number}.)\n"

                img_number += 1

        # Fill an optional {total} placeholder. Use replace() instead of
        # str.format so stray braces inside model-generated captions cannot
        # raise (format() on the accumulated text previously could).
        output_captions = output_captions.replace("{total}", str(img_number - 1))

        if output_captions == "":
            return content

        if txt_idx is not None:
            content[txt_idx]["text"] += output_captions
        else:
            content.insert(0, {"type": "text", "text": output_captions})

        return content

    def _validate_modality_support(
        self, agent_has_image_modality: bool, drop_unsupported: bool, image_captioner: Optional[ImageCaptioner]
    ) -> None:
        """Raise ValueError for invalid constructor argument combinations."""
        if agent_has_image_modality and drop_unsupported:
            raise ValueError("Cannot drop unsupported modalities when the agent has an image modality.")

        if not image_captioner and not agent_has_image_modality:
            raise ValueError("Must provide an image captioner when the agent does not have an image modality.")


class DropUnsupportedModalities:
    """Transform that removes content items whose modality the agent does not support."""

    def __init__(self, supported_modalities: Sequence[ModalitiesType] = ()):
        """
        Args:
            supported_modalities: coarse modalities to keep; each is expanded to
                its concrete item-type aliases via MODALITIES_ALIAS. Defaults to
                an empty (immutable) tuple — previously ``list()``, a mutable
                default argument.
        """
        self._supported_modalities = _expand_supported_modalities(supported_modalities, MODALITIES_ALIAS)

    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
        """Filter each message's multimodal content list in place and return the list.

        Plain-string and missing content have no typed items, so they pass
        through unchanged; non-dict items in a content list are dropped.
        """
        for message in messages:
            content = message.get("content")
            if content is None or not isinstance(content, list):
                continue

            message["content"] = [
                item
                for item in content
                if isinstance(item, dict) and item.get("type") in self._supported_modalities
            ]
        return messages


def _drop_unsupported_factory(
    unsupported_agent_modalities: Optional[List[ModalitiesType]],
    modalities_alias: Dict[ModalitiesType, List[str]],
) -> DropUnsupportedModalities:
    """Build a DropUnsupportedModalities keeping all modalities except the given ones.

    Args:
        unsupported_agent_modalities: modalities to exclude from the supported
            set; ``None`` means nothing is excluded.
        modalities_alias: mapping of coarse modality names to their concrete
            content item types (the full universe of modalities).

    Returns:
        A transform that drops every modality in ``unsupported_agent_modalities``.
    """
    if unsupported_agent_modalities is None:
        supported_modalities: List[ModalitiesType] = list(modalities_alias.keys())
    else:
        supported_modalities = [
            modal for modal in modalities_alias.keys() if modal not in unsupported_agent_modalities
        ]

    return DropUnsupportedModalities(supported_modalities)


def _expand_supported_modalities(
supported_modalities: Sequence[ModalitiesType], modalities_mapping: Dict[ModalitiesType, List[str]]
) -> Set[str]:
expanded_modalities = set()
for modality in supported_modalities:
expanded_modalities.update(modalities_mapping.get(modality, []))
return expanded_modalities


def _count_content_type(content: MessageContentType, message_type: str) -> int:
total_count = 0
if isinstance(content, str) and message_type == "text":
total_count += 1

if isinstance(content, list) and len(content) > 0:
for item in content:
if isinstance(item, str) and message_type == "text":
total_count += 1
elif isinstance(item, dict) and item.get("type") == message_type:
total_count += 1

return total_count
4 changes: 2 additions & 2 deletions autogen/agentchat/contrib/img_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import re
from io import BytesIO
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Literal, Tuple, Union

import requests
from PIL import Image
Expand Down Expand Up @@ -163,7 +163,7 @@ def _get_mime_type_from_data_uri(base64_image):
return data_uri


def gpt4v_formatter(prompt: str, img_format: str = "uri") -> List[Union[str, dict]]:
def gpt4v_formatter(prompt: str, img_format: Literal["uri", "url", "pil"] = "uri") -> List[Union[str, dict]]:
"""
Formats the input prompt by replacing image tags and returns a list of text and images.

Expand Down
35 changes: 32 additions & 3 deletions autogen/agentchat/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import copy
import re
from typing import Any, Callable, Dict, List, Union

from autogen.types import MessageContentType

from .agent import Agent


Expand Down Expand Up @@ -96,7 +99,7 @@ def aggregate_summary(usage_summary: Dict[str, Any], agent_summary: Dict[str, An
}


def parse_tags_from_content(tag: str, content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Dict[str, str]]]:
def parse_tags_from_content(tag: str, content: MessageContentType) -> List[Dict[str, Dict[str, str]]]:
"""Parses HTML style tags from message contents.

The parsing is done by looking for patterns in the text that match the format of HTML tags. The tag to be parsed is
Expand Down Expand Up @@ -128,8 +131,11 @@ def parse_tags_from_content(tag: str, content: Union[str, List[Dict[str, Any]]])
# Handles case for multimodal messages.
elif isinstance(content, list):
for item in content:
if item.get("type") == "text":
results.extend(_parse_tags_from_text(tag, item["text"]))
if isinstance(item, str):
results.extend(_parse_tags_from_text(tag, item))
else:
if item.get("type") == "text":
results.extend(_parse_tags_from_text(tag, item["text"]))
else:
raise ValueError(f"content must be str or list, but got {type(content)}")

Expand Down Expand Up @@ -174,6 +180,29 @@ def _append_src_value(content, value):
return content


def replace_tag_in_content(
    tag: Dict[str, Any], content: Union[List[Union[Dict, str]], str], replacement_text: str
) -> Union[List[Union[Dict, str]], str]:
    """Return a copy of the content with the tag's matched text replaced.

    Args:
        tag: a parsed tag dict as produced by parse_tags_from_content; its
            "match" entry is a regex match whose group() is the text to replace.
        content: plain-string or multimodal (list) message content.
        replacement_text: text substituted for every occurrence of the tag.

    The input is deep-copied first, so the caller's content is never mutated.
    """
    content = copy.deepcopy(content)
    # isinstance against the builtin `list`, not the deprecated typing.List alias.
    if isinstance(content, list):
        return _multimodal_replace_tag_in_content(content, tag, replacement_text)
    else:
        return content.replace(tag["match"].group(), replacement_text)


def _multimodal_replace_tag_in_content(
content: List[Union[Dict, str]], tag: Dict[str, Any], replacement_text: str
) -> List[Union[Dict, str]]:
modified_content = []
for item in content:
if isinstance(item, str):
modified_content.append(item.replace(tag["match"].group(), replacement_text))
else:
item["text"] = item["text"].replace(tag["match"].group(), replacement_text)
modified_content.append(item)
return modified_content


def _reconstruct_attributes(attrs: List[str]) -> List[str]:
"""Reconstructs attributes from a list of strings where some attributes may be split across multiple elements."""

Expand Down
Loading