Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple images #151

Merged
merged 13 commits into from
Dec 14, 2024
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
if: github.event_name == 'pull_request' # pre-commit-ci/lite-action only runs here
strategy:
matrix:
python-version: [3.11, 3.13] # Our min and max supported Python versions
ludomitch marked this conversation as resolved.
Show resolved Hide resolved
python-version: [3.11, 3.12] # Our min and max supported Python versions
steps:
- uses: actions/checkout@v4
with:
Expand All @@ -27,7 +27,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.11, 3.13] # Our min and max supported Python versions
python-version: [3.11, 3.12] # Our min and max supported Python versions
steps:
- uses: actions/checkout@v4
with:
Expand Down
2 changes: 1 addition & 1 deletion .python-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.13
3.12
21 changes: 14 additions & 7 deletions src/aviary/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
from collections.abc import Iterable
from typing import TYPE_CHECKING, ClassVar, Self

import numpy as np
ludomitch marked this conversation as resolved.
Show resolved Hide resolved
from pydantic import BaseModel, Field, field_validator, model_validator

from aviary.utils import encode_image_to_base64
from aviary.utils import check_if_valid_base64, encode_image_to_base64

if TYPE_CHECKING:
from logging import LogRecord

import numpy as np


class Message(BaseModel):
DEFAULT_ROLE: ClassVar[str] = "user"
Expand Down Expand Up @@ -124,16 +123,24 @@ def create_message(
cls,
role: str = DEFAULT_ROLE,
text: str | None = None,
image: np.ndarray | None = None,
images: list[np.ndarray | str] | str | np.ndarray | None = None,
ludomitch marked this conversation as resolved.
Show resolved Hide resolved
) -> Self:
# Assume no image, and update to image if present
# Assume no images, and update to images if present
content: str | list[dict] | None = text
if image is not None:
if images is not None:
if isinstance(images, str | np.ndarray):
images = [images]
content = [
{
"type": "image_url",
"image_url": {"url": encode_image_to_base64(image)},
"image_url": {
"url": encode_image_to_base64(image)
# If image is a string, assume it's already a base64 encoded image
if isinstance(image, np.ndarray)
else check_if_valid_base64(image)
},
}
for image in images
]
if text is not None:
content.append({"type": "text", "text": text})
Expand Down
9 changes: 9 additions & 0 deletions src/aviary/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ def encode_image_to_base64(img: "np.ndarray") -> str:
)


def check_if_valid_base64(image: str) -> str:
"""Check if the input string is a valid base64 encoded image."""
try:
base64.b64decode(image)
except Exception as err:
raise ValueError("Invalid base64 encoded image") from err
return image


def is_coroutine_callable(obj) -> bool:
"""Get if the input object is awaitable."""
if inspect.isfunction(obj) or inspect.ismethod(obj):
Expand Down
1 change: 1 addition & 0 deletions tests/fixtures/test_images/sample_image1.b64
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

104 changes: 91 additions & 13 deletions tests/test_messages.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import pathlib

import numpy as np
import pytest
Expand All @@ -11,6 +12,12 @@
ToolResponseMessage,
)

FIXTURES_DIR = pathlib.Path(__file__).parent / "fixtures" / "test_images"


def load_base64_image(filename: str) -> str:
return (FIXTURES_DIR / filename).read_text().strip()


class TestMessage:
def test_roles(self) -> None:
Expand Down Expand Up @@ -112,20 +119,91 @@ def test_str(self, message: Message, expected: str) -> None:
def test_dump(self, message: Message, expected: dict) -> None:
assert message.model_dump(exclude_none=True) == expected

def test_image_message(self) -> None:
# An RGB image of a red square
image = np.zeros((32, 32, 3), dtype=np.uint8)
image[:] = [255, 0, 0] # (255 red, 0 green, 0 blue) is maximum red in RGB
message_text = "What color is this square? Respond only with the color name."
message_with_image = Message.create_message(text=message_text, image=image)
assert message_with_image.content
specialized_content = json.loads(message_with_image.content)
assert len(specialized_content) == 2
text_idx, image_idx = (
(0, 1) if specialized_content[0]["type"] == "text" else (1, 0)
)
@pytest.mark.parametrize(
("images", "message_text", "expected_error", "expected_content_length"),
[
# Case 1: Invalid base64 image should raise error
(
[
np.zeros((32, 32, 3), dtype=np.uint8), # red square
"_base64_content", # invalid base64
],
"What color are these squares? List each color.",
"Invalid base64 encoded image",
None,
),
# Case 2: Valid images should work
(
[
np.zeros((32, 32, 3), dtype=np.uint8), # red square
load_base64_image("sample_image1.b64"),
],
"What color are these squares? List each color.",
None,
3, # 2 images + 1 text
),
# Case 3: A numpy array in non-list formatshould be converted to a base64 encoded image
(
np.zeros((32, 32, 3), dtype=np.uint8), # red square
"What color is this square?",
None,
2, # 1 image + 1 text
),
# Case 4: A string should be converted to a base64 encoded image
(
load_base64_image("sample_image1.b64"),
"What color is this square?",
None,
2, # 1 image + 1 text
),
],
)
def test_image_message(
self,
images: list[np.ndarray | str] | np.ndarray | str,
message_text: str,
expected_error: str | None,
expected_content_length: int | None,
) -> None:
# Set red color for numpy array if present
for img in images:
if isinstance(img, np.ndarray):
img[:] = [255, 0, 0] # (255 red, 0 green, 0 blue) is maximum red in RGB

if expected_error:
with pytest.raises(ValueError, match=expected_error):
Message.create_message(text=message_text, images=images)
return

message_with_images = Message.create_message(text=message_text, images=images)
assert message_with_images.content
specialized_content = json.loads(message_with_images.content)
assert len(specialized_content) == expected_content_length

# Find indices of each content type
image_indices = []
text_idx = None
for i, content in enumerate(specialized_content):
if content["type"] == "image_url":
image_indices.append(i)
else:
text_idx = i

if isinstance(images, list):
assert len(image_indices) == len(images)
else:
assert len(image_indices) == 1
assert text_idx is not None
assert specialized_content[text_idx]["text"] == message_text
assert "image_url" in specialized_content[image_idx]

# Check both images are properly formatted
for idx in image_indices:
assert "image_url" in specialized_content[idx]
assert "url" in specialized_content[idx]["image_url"]
# Both images should be base64 encoded
assert specialized_content[idx]["image_url"]["url"].startswith(
"data:image/"
)


class TestToolRequestMessage:
Expand Down
Loading