From 349289fbbd7192bbfab2cd03a650678b19438e95 Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 13 Dec 2024 11:24:14 -0800 Subject: [PATCH 01/13] message update + tests --- src/aviary/message.py | 17 +++++++++++----- tests/test_messages.py | 46 +++++++++++++++++++++++++++++++++--------- 2 files changed, 48 insertions(+), 15 deletions(-) diff --git a/src/aviary/message.py b/src/aviary/message.py index fce01a14..92c3081a 100644 --- a/src/aviary/message.py +++ b/src/aviary/message.py @@ -1,6 +1,8 @@ from __future__ import annotations import json +import numpy as np + from collections.abc import Iterable from typing import TYPE_CHECKING, ClassVar, Self @@ -11,7 +13,6 @@ if TYPE_CHECKING: from logging import LogRecord - import numpy as np class Message(BaseModel): @@ -124,16 +125,22 @@ def create_message( cls, role: str = DEFAULT_ROLE, text: str | None = None, - image: np.ndarray | None = None, + images: list[np.ndarray | str] | None = None, ) -> Self: - # Assume no image, and update to image if present + # Assume no images, and update to images if present content: str | list[dict] | None = text - if image is not None: + if images is not None: content = [ { "type": "image_url", - "image_url": {"url": encode_image_to_base64(image)}, + "image_url": { + "url": encode_image_to_base64(image) + # If image is a string, assume it's already a base64 encoded image + if isinstance(image, np.ndarray) + else image + }, } + for image in images ] if text is not None: content.append({"type": "text", "text": text}) diff --git a/tests/test_messages.py b/tests/test_messages.py index 10c6af54..f3a7c1b2 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -114,18 +114,44 @@ def test_dump(self, message: Message, expected: dict) -> None: def test_image_message(self) -> None: # An RGB image of a red square - image = np.zeros((32, 32, 3), dtype=np.uint8) - image[:] = [255, 0, 0] # (255 red, 0 green, 0 blue) is maximum red in RGB - message_text = "What color is this square? Respond only with the color name." - message_with_image = Message.create_message(text=message_text, image=image) - assert message_with_image.content - specialized_content = json.loads(message_with_image.content) - assert len(specialized_content) == 2 - text_idx, image_idx = ( - (0, 1) if specialized_content[0]["type"] == "text" else (1, 0) + red_square = np.zeros((32, 32, 3), dtype=np.uint8) + red_square[:] = [255, 0, 0] # (255 red, 0 green, 0 blue) is maximum red in RGB + + # A pre-encoded base64 image (simulated) + encoded_image = "data:image/jpeg;base64,fake_base64_content" + + message_text = "What color are these squares? List each color." + message_with_images = Message.create_message( + text=message_text, + images=[red_square, encoded_image] ) + + assert message_with_images.content + specialized_content = json.loads(message_with_images.content) + assert len(specialized_content) == 3 # 2 images + 1 text + + # Find indices of each content type + image_indices = [] + text_idx = None + for i, content in enumerate(specialized_content): + if content["type"] == "image_url": + image_indices.append(i) + else: + text_idx = i + + assert len(image_indices) == 2 + assert text_idx is not None assert specialized_content[text_idx]["text"] == message_text - assert "image_url" in specialized_content[image_idx] + + # Check both images are properly formatted + for idx in image_indices: + assert "image_url" in specialized_content[idx] + assert "url" in specialized_content[idx]["image_url"] + # First image should be base64 encoded, second should be the raw string + if idx == image_indices[0]: + assert specialized_content[idx]["image_url"]["url"].startswith("data:image/") + else: + assert specialized_content[idx]["image_url"]["url"] == encoded_image class TestToolRequestMessage: From aec552abd5ad5730a80535dbc18428bc9003519c Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 13 Dec 2024 11:29:32 -0800 Subject: [PATCH 02/13] pre-commit run --- src/aviary/message.py | 4 +--- tests/test_messages.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/aviary/message.py b/src/aviary/message.py index 92c3081a..402e8498 100644 --- a/src/aviary/message.py +++ b/src/aviary/message.py @@ -1,11 +1,10 @@ from __future__ import annotations import json -import numpy as np - from collections.abc import Iterable from typing import TYPE_CHECKING, ClassVar, Self +import numpy as np from pydantic import BaseModel, Field, field_validator, model_validator from aviary.utils import encode_image_to_base64 @@ -14,7 +13,6 @@ from logging import LogRecord - class Message(BaseModel): DEFAULT_ROLE: ClassVar[str] = "user" VALID_ROLES: ClassVar[set[str]] = { diff --git a/tests/test_messages.py b/tests/test_messages.py index f3a7c1b2..8425c015 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -116,20 +116,19 @@ def test_image_message(self) -> None: # An RGB image of a red square red_square = np.zeros((32, 32, 3), dtype=np.uint8) red_square[:] = [255, 0, 0] # (255 red, 0 green, 0 blue) is maximum red in RGB - + # A pre-encoded base64 image (simulated) encoded_image = "data:image/jpeg;base64,fake_base64_content" - + message_text = "What color are these squares? List each color." message_with_images = Message.create_message( - text=message_text, - images=[red_square, encoded_image] + text=message_text, images=[red_square, encoded_image] ) - + assert message_with_images.content specialized_content = json.loads(message_with_images.content) assert len(specialized_content) == 3 # 2 images + 1 text - + # Find indices of each content type image_indices = [] text_idx = None @@ -138,18 +137,20 @@ def test_image_message(self) -> None: image_indices.append(i) else: text_idx = i - + assert len(image_indices) == 2 assert text_idx is not None assert specialized_content[text_idx]["text"] == message_text - + # Check both images are properly formatted for idx in image_indices: assert "image_url" in specialized_content[idx] assert "url" in specialized_content[idx]["image_url"] # First image should be base64 encoded, second should be the raw string if idx == image_indices[0]: - assert specialized_content[idx]["image_url"]["url"].startswith("data:image/") + assert specialized_content[idx]["image_url"]["url"].startswith( + "data:image/" + ) else: assert specialized_content[idx]["image_url"]["url"] == encoded_image From 612be2f8bddf111e7de9c9e5eb0ae48e9ff29cae Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 13 Dec 2024 11:37:49 -0800 Subject: [PATCH 03/13] update lock --- uv.lock | 64 +++++---------------------------------------------------- 1 file changed, 5 insertions(+), 59 deletions(-) diff --git a/uv.lock b/uv.lock index 6436548e..d7248ce9 100644 --- a/uv.lock +++ b/uv.lock @@ -177,7 +177,7 @@ wheels = [ [[package]] name = "aviary-gsm8k" -version = "0.11.1.dev3+gef1e2d5.d20241204" +version = "0.12.1.dev1+g602cf53.d20241213" source = { editable = "packages/gsm8k" } dependencies = [ { name = "datasets" }, @@ -200,7 +200,7 @@ requires-dist = [ [[package]] name = "aviary-hotpotqa" -version = "0.11.1.dev3+gef1e2d5.d20241204" +version = "0.12.1.dev1+g602cf53.d20241213" source = { editable = "packages/hotpotqa" } dependencies = [ { name = "beautifulsoup4" }, @@ -825,7 +825,7 @@ wheels = [ [[package]] name = "fhaviary" -version = "0.11.1.dev4+g1e8ea27.d20241204" +version = "0.12.1.dev1+g602cf53.d20241213" source = { editable = "." } dependencies = [ { name = "docstring-parser" }, @@ -899,65 +899,11 @@ xml = [ [package.dev-dependencies] codeflash = [ - { name = "aviary-gsm8k", extra = ["typing"] }, - { name = "aviary-hotpotqa" }, - { name = "boto3-stubs", extra = ["s3"] }, - { name = "click" }, - { name = "cloudpickle" }, { name = "codeflash" }, - { name = "dicttoxml" }, - { name = "fastapi" }, - { name = "httpx" }, - { name = "ipython" }, - { name = "litellm" }, - { name = "mypy" }, - { name = "numpy" }, - { name = "pillow" }, - { name = "pre-commit" }, - { name = "pydantic" }, - { name = "pylint" }, - { name = "pylint-pydantic" }, - { name = "pytest" }, - { name = "pytest-asyncio" }, - { name = "pytest-recording" }, - { name = "pytest-subtests" }, - { name = "pytest-sugar" }, - { name = "pytest-timer", extra = ["colorama"] }, - { name = "pytest-xdist" }, - { name = "refurb" }, - { name = "typeguard" }, - { name = "types-pillow" }, - { name = "uvicorn" }, + { name = "fhaviary", extra = ["dev"] }, ] dev = [ - { name = "aviary-gsm8k", extra = ["typing"] }, - { name = "aviary-hotpotqa" }, - { name = "boto3-stubs", extra = ["s3"] }, - { name = "click" }, - { name = "cloudpickle" }, - { name = "dicttoxml" }, - { name = "fastapi" }, - { name = "httpx" }, - { name = "ipython" }, - { name = "litellm" }, - { name = "mypy" }, - { name = "numpy" }, - { name = "pillow" }, - { name = "pre-commit" }, - { name = "pydantic" }, - { name = "pylint" }, - { name = "pylint-pydantic" }, - { name = "pytest" }, - { name = "pytest-asyncio" }, - { name = "pytest-recording" }, - { name = "pytest-subtests" }, - { name = "pytest-sugar" }, - { name = "pytest-timer", extra = ["colorama"] }, - { name = "pytest-xdist" }, - { name = "refurb" }, - { name = "typeguard" }, - { name = "types-pillow" }, - { name = "uvicorn" }, + { name = "fhaviary", extra = ["dev"] }, ] [package.metadata] From cd54f5a71eb3ae1bde1956eb6d449839d2254e1a Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 13 Dec 2024 11:45:33 -0800 Subject: [PATCH 04/13] Revert "update lock" This reverts commit 612be2f8bddf111e7de9c9e5eb0ae48e9ff29cae. --- uv.lock | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 59 insertions(+), 5 deletions(-) diff --git a/uv.lock b/uv.lock index d7248ce9..6436548e 100644 --- a/uv.lock +++ b/uv.lock @@ -177,7 +177,7 @@ wheels = [ [[package]] name = "aviary-gsm8k" -version = "0.12.1.dev1+g602cf53.d20241213" +version = "0.11.1.dev3+gef1e2d5.d20241204" source = { editable = "packages/gsm8k" } dependencies = [ { name = "datasets" }, @@ -200,7 +200,7 @@ requires-dist = [ [[package]] name = "aviary-hotpotqa" -version = "0.12.1.dev1+g602cf53.d20241213" +version = "0.11.1.dev3+gef1e2d5.d20241204" source = { editable = "packages/hotpotqa" } dependencies = [ { name = "beautifulsoup4" }, @@ -825,7 +825,7 @@ wheels = [ [[package]] name = "fhaviary" -version = "0.12.1.dev1+g602cf53.d20241213" +version = "0.11.1.dev4+g1e8ea27.d20241204" source = { editable = "." } dependencies = [ { name = "docstring-parser" }, @@ -899,11 +899,65 @@ xml = [ [package.dev-dependencies] codeflash = [ + { name = "aviary-gsm8k", extra = ["typing"] }, + { name = "aviary-hotpotqa" }, + { name = "boto3-stubs", extra = ["s3"] }, + { name = "click" }, + { name = "cloudpickle" }, { name = "codeflash" }, - { name = "fhaviary", extra = ["dev"] }, + { name = "dicttoxml" }, + { name = "fastapi" }, + { name = "httpx" }, + { name = "ipython" }, + { name = "litellm" }, + { name = "mypy" }, + { name = "numpy" }, + { name = "pillow" }, + { name = "pre-commit" }, + { name = "pydantic" }, + { name = "pylint" }, + { name = "pylint-pydantic" }, + { name = "pytest" }, + { name = "pytest-asyncio" }, + { name = "pytest-recording" }, + { name = "pytest-subtests" }, + { name = "pytest-sugar" }, + { name = "pytest-timer", extra = ["colorama"] }, + { name = "pytest-xdist" }, + { name = "refurb" }, + { name = "typeguard" }, + { name = "types-pillow" }, + { name = "uvicorn" }, ] dev = [ - { name = "fhaviary", extra = ["dev"] }, + { name = "aviary-gsm8k", extra = ["typing"] }, + { name = "aviary-hotpotqa" }, + { name = "boto3-stubs", extra = ["s3"] }, + { name = "click" }, + { name = "cloudpickle" }, + { name = "dicttoxml" }, + { name = "fastapi" }, + { name = "httpx" }, + { name = "ipython" }, + { name = "litellm" }, + { name = "mypy" }, + { name = "numpy" }, + { name = "pillow" }, + { name = "pre-commit" }, + { name = "pydantic" }, + { name = "pylint" }, + { name = "pylint-pydantic" }, + { name = "pytest" }, + { name = "pytest-asyncio" }, + { name = "pytest-recording" }, + { name = "pytest-subtests" }, + { name = "pytest-sugar" }, + { name = "pytest-timer", extra = ["colorama"] }, + { name = "pytest-xdist" }, + { name = "refurb" }, + { name = "typeguard" }, + { name = "types-pillow" }, + { name = "uvicorn" }, ] [package.metadata] From 6cd0c4743bc526b2a1c72189c7e5290a0205356c Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 13 Dec 2024 13:57:26 -0800 Subject: [PATCH 05/13] renovate cleanup --- .python-version | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.python-version b/.python-version index 24ee5b1b..e4fba218 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.13 +3.12 From 1c13b7a2f2cfaecf2f23e3a59ef876f1b3b079ae Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 13 Dec 2024 14:27:31 -0800 Subject: [PATCH 06/13] update gh workflow versions and expand scope of images support --- .github/workflows/tests.yml | 4 +- src/aviary/message.py | 8 ++-- src/aviary/utils.py | 10 +++++ tests/test_messages.py | 88 +++++++++++++++++++++++++++---------- 4 files changed, 83 insertions(+), 27 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f23a7153..2553ee5f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -12,7 +12,7 @@ jobs: if: github.event_name == 'pull_request' # pre-commit-ci/lite-action only runs here strategy: matrix: - python-version: [3.11, 3.13] # Our min and max supported Python versions + python-version: [3.11, 3.12] # Our min and max supported Python versions steps: - uses: actions/checkout@v4 with: @@ -27,7 +27,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.11, 3.13] # Our min and max supported Python versions + python-version: [3.11, 3.12] # Our min and max supported Python versions steps: - uses: actions/checkout@v4 with: diff --git a/src/aviary/message.py b/src/aviary/message.py index 402e8498..95440898 100644 --- a/src/aviary/message.py +++ b/src/aviary/message.py @@ -7,7 +7,7 @@ import numpy as np from pydantic import BaseModel, Field, field_validator, model_validator -from aviary.utils import encode_image_to_base64 +from aviary.utils import check_if_valid_base64, encode_image_to_base64 if TYPE_CHECKING: from logging import LogRecord @@ -123,11 +123,13 @@ def create_message( cls, role: str = DEFAULT_ROLE, text: str | None = None, - images: list[np.ndarray | str] | None = None, + images: list[np.ndarray | str] | str | np.ndarray | None = None, ) -> Self: # Assume no images, and update to images if present content: str | list[dict] | None = text if images is not None: + if isinstance(images, str | np.ndarray): + images = [images] content = [ { "type": "image_url", @@ -135,7 +137,7 @@ def create_message( "url": encode_image_to_base64(image) # If image is a string, assume it's already a base64 encoded image if isinstance(image, np.ndarray) - else image + else check_if_valid_base64(image) }, } for image in images diff --git a/src/aviary/utils.py b/src/aviary/utils.py index 4d4fb3ef..e85fbaad 100644 --- a/src/aviary/utils.py +++ b/src/aviary/utils.py @@ -70,6 +70,16 @@ def encode_image_to_base64(img: "np.ndarray") -> str: ) +def check_if_valid_base64(image: str) -> str: + """Check if the input string is a valid base64 encoded image.""" + try: + base64.b64decode(image) + except Exception as err: + raise ValueError("Invalid base64 encoded image") from err + else: + return image + + def is_coroutine_callable(obj) -> bool: """Get if the input object is awaitable.""" if inspect.isfunction(obj) or inspect.ismethod(obj): diff --git a/tests/test_messages.py b/tests/test_messages.py index 8425c015..33a01394 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -112,22 +112,66 @@ def test_str(self, message: Message, expected: str) -> None: def test_dump(self, message: Message, expected: dict) -> None: assert message.model_dump(exclude_none=True) == expected - def test_image_message(self) -> None: - # An RGB image of a red square - red_square = np.zeros((32, 32, 3), dtype=np.uint8) - red_square[:] = [255, 0, 0] # (255 red, 0 green, 0 blue) is maximum red in RGB - - # A pre-encoded base64 image (simulated) - encoded_image = "data:image/jpeg;base64,fake_base64_content" - - message_text = "What color are these squares? List each color." - message_with_images = Message.create_message( - text=message_text, images=[red_square, encoded_image] - ) - + @pytest.mark.parametrize( + ("images", "message_text", "expected_error", "expected_content_length"), + [ + # Case 1: Invalid base64 image should raise error + ( + [ + np.zeros((32, 32, 3), dtype=np.uint8), # red square + "data:image/jpeg;base64,fake_base64_content", # invalid base64 + ], + "What color are these squares? List each color.", + "Invalid base64 encoded image", + None, + ), + # Case 2: Valid images should work + ( + [ + np.zeros((32, 32, 3), dtype=np.uint8), # red square + "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEASABIAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAJYAlgDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAb/xAAUEAEAAAAAAAAAAAAAAAAAAAAA/8QAFQEBAQAAAAAAAAAAAAAAAAAAAAb/xAAUEQEAAAAAAAAAAAAAAAAAAAAA/9oADAMBAAIRAxEAPwCSAWCdAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAf/2Q==", # valid base64 + ], + "What color are these squares? List each color.", + None, + 3, # 2 images + 1 text + ), + # Case 3: A numpy array in non-list formatshould be converted to a base64 encoded image + ( + np.zeros((32, 32, 3), dtype=np.uint8), # red square + "What color is this square?", + None, + 2, # 1 image + 1 text + ), + # Case 4: A string should be converted to a base64 encoded image + ( + "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEASABIAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAJYAlgDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAb/xAAUEAEAAAAAAAAAAAAAAAAAAAAA/8QAFQEBAQAAAAAAAAAAAAAAAAAAAAb/xAAUEQEAAAAAAAAAAAAAAAAAAAAA/9oADAMBAAIRAxEAPwCSAWCdAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAf/2Q==", # valid base64 + "What color is this square?", + None, + 2, # 1 image + 1 text + ), + ], + ) + def test_image_message( + self, + images: list[np.ndarray | str] | np.ndarray | str, + message_text: str, + expected_error: str | None, + expected_content_length: int | None, + ) -> None: + # Set red color for numpy array if present + for img in images: + if isinstance(img, np.ndarray): + img[:] = [255, 0, 0] # (255 red, 0 green, 0 blue) is maximum red in RGB + + if expected_error: + with pytest.raises(ValueError, match=expected_error): + Message.create_message(text=message_text, images=images) + return + + message_with_images = Message.create_message(text=message_text, images=images) assert message_with_images.content specialized_content = json.loads(message_with_images.content) - assert len(specialized_content) == 3 # 2 images + 1 text + assert len(specialized_content) == expected_content_length # Find indices of each content type image_indices = [] @@ -138,7 +182,10 @@ def test_image_message(self) -> None: else: text_idx = i - assert len(image_indices) == 2 + if isinstance(images, list): + assert len(image_indices) == len(images) + else: + assert len(image_indices) == 1 assert text_idx is not None assert specialized_content[text_idx]["text"] == message_text @@ -146,13 +193,10 @@ def test_image_message(self) -> None: for idx in image_indices: assert "image_url" in specialized_content[idx] assert "url" in specialized_content[idx]["image_url"] - # First image should be base64 encoded, second should be the raw string - if idx == image_indices[0]: - assert specialized_content[idx]["image_url"]["url"].startswith( - "data:image/" - ) - else: - assert specialized_content[idx]["image_url"]["url"] == encoded_image + # Both images should be base64 encoded + assert specialized_content[idx]["image_url"]["url"].startswith( + "data:image/" + ) class TestToolRequestMessage: From bab1c73156c79df5ca8c6878c4a613dcfcf78daf Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 13 Dec 2024 14:32:54 -0800 Subject: [PATCH 07/13] Add fixtures --- tests/fixtures/test_images/sample_image1.b64 | 1 + tests/test_messages.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) create mode 100644 tests/fixtures/test_images/sample_image1.b64 diff --git a/tests/fixtures/test_images/sample_image1.b64 b/tests/fixtures/test_images/sample_image1.b64 new file mode 100644 index 00000000..fdefbae2 --- /dev/null +++ b/tests/fixtures/test_images/sample_image1.b64 @@ -0,0 +1 @@ +data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEASABIAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAJYAlgDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAb/xAAUEAEAAAAAAAAAAAAAAAAAAAAA/8QAFQEBAQAAAAAAAAAAAAAAAAAAAAb/xAAUEQEAAAAAAAAAAAAAAAAAAAAA/9oADAMBAAIRAxEAPwCSAWCdAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAf/2Q== diff --git a/tests/test_messages.py b/tests/test_messages.py index 33a01394..6e787238 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -1,4 +1,5 @@ import json +import pathlib import numpy as np import pytest @@ -11,6 +12,12 @@ ToolResponseMessage, ) +FIXTURES_DIR = pathlib.Path(__file__).parent / "fixtures" / "test_images" + + +def load_base64_image(filename: str) -> str: + return (FIXTURES_DIR / filename).read_text().strip() + class TestMessage: def test_roles(self) -> None: @@ -129,7 +136,7 @@ def test_dump(self, message: Message, expected: dict) -> None: ( [ np.zeros((32, 32, 3), dtype=np.uint8), # red square - "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEASABIAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAJYAlgDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAb/xAAUEAEAAAAAAAAAAAAAAAAAAAAA/8QAFQEBAQAAAAAAAAAAAAAAAAAAAAb/xAAUEQEAAAAAAAAAAAAAAAAAAAAA/9oADAMBAAIRAxEAPwCSAWCdAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAf/2Q==", # valid base64 + load_base64_image("sample_image1.b64"), ], "What color are these squares? List each color.", None, @@ -144,7 +151,7 @@ def test_dump(self, message: Message, expected: dict) -> None: ), # Case 4: A string should be converted to a base64 encoded image ( - "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEASABIAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAJYAlgDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAb/xAAUEAEAAAAAAAAAAAAAAAAAAAAA/8QAFQEBAQAAAAAAAAAAAAAAAAAAAAb/xAAUEQEAAAAAAAAAAAAAAAAAAAAA/9oADAMBAAIRAxEAPwCSAWCdAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAf/2Q==", # valid base64 + load_base64_image("sample_image1.b64"), "What color is this square?", None, 2, # 1 image + 1 text From 6ba8ec0b77266d5e676f827a711ad989a6d0acaa Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 13 Dec 2024 14:38:58 -0800 Subject: [PATCH 08/13] making ruff + pylint happy --- src/aviary/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/aviary/utils.py b/src/aviary/utils.py index e85fbaad..c4ecf38c 100644 --- a/src/aviary/utils.py +++ b/src/aviary/utils.py @@ -76,8 +76,7 @@ def check_if_valid_base64(image: str) -> str: base64.b64decode(image) except Exception as err: raise ValueError("Invalid base64 encoded image") from err - else: - return image + return image def is_coroutine_callable(obj) -> bool: From 29bdd091d14833ecbda994fd8d1df9c9fbcee140 Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 13 Dec 2024 14:54:57 -0800 Subject: [PATCH 09/13] bump versioning back up and add descriptive docstring --- .github/workflows/tests.yml | 4 ++-- src/aviary/message.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2553ee5f..b2ca22ea 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -12,7 +12,7 @@ jobs: if: github.event_name == 'pull_request' # pre-commit-ci/lite-action only runs here strategy: matrix: - python-version: [3.11, 3.12] # Our min and max supported Python versions + python-version: [3.11, 3.13.0] # Our min and max supported Python versions steps: - uses: actions/checkout@v4 with: @@ -27,7 +27,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.11, 3.12] # Our min and max supported Python versions + python-version: [3.11, 3.13.0] # Our min and max supported Python versions steps: - uses: actions/checkout@v4 with: diff --git a/src/aviary/message.py b/src/aviary/message.py index 95440898..775d9cef 100644 --- a/src/aviary/message.py +++ b/src/aviary/message.py @@ -125,6 +125,18 @@ def create_message( text: str | None = None, images: list[np.ndarray | str] | str | np.ndarray | None = None, ) -> Self: + """Create a message with optional text and images. + + Args: + role: The role of the message. + text: The text of the message. + images: The images to include in the message. This can be a single image or + a list of images. Images can be a numpy array or a base64 encoded image + string (str). + + Returns: + The created message. + """ # Assume no images, and update to images if present content: str | list[dict] | None = text if images is not None: From 07a59003064fece45d7a838e4c6fb9a3c237e70c Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 13 Dec 2024 15:30:51 -0800 Subject: [PATCH 10/13] Add numpy to dependencies --- pyproject.toml | 1 + uv.lock | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7d014f63..74278a60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ classifiers = [ dependencies = [ "docstring_parser>=0.16", # Pin for description addition "httpx", + "numpy", "pydantic~=2.0", ] description = "Gymnasium framework for training language model agents on constructive tasks" diff --git a/uv.lock b/uv.lock index 6436548e..c6cc2f25 100644 --- a/uv.lock +++ b/uv.lock @@ -177,7 +177,7 @@ wheels = [ [[package]] name = "aviary-gsm8k" -version = "0.11.1.dev3+gef1e2d5.d20241204" +version = "0.12.1.dev3+gaec552a.d20241213" source = { editable = "packages/gsm8k" } dependencies = [ { name = "datasets" }, @@ -200,7 +200,7 @@ requires-dist = [ [[package]] name = "aviary-hotpotqa" -version = "0.11.1.dev3+gef1e2d5.d20241204" +version = "0.12.1.dev3+gaec552a.d20241213" source = { editable = "packages/hotpotqa" } dependencies = [ { name = "beautifulsoup4" }, @@ -825,11 +825,12 @@ wheels = [ [[package]] name = "fhaviary" -version = "0.11.1.dev4+g1e8ea27.d20241204" +version = "0.12.1.dev10+g29bdd09.d20241213" source = { editable = "." } dependencies = [ { name = "docstring-parser" }, { name = "httpx" }, + { name = "numpy" }, { name = "pydantic" }, ] @@ -980,6 +981,7 @@ requires-dist = [ { name = "litellm", marker = "python_full_version >= '3.13' and extra == 'llm'", specifier = ">=1.49.1" }, { name = "litellm", marker = "python_full_version < '3.13' and extra == 'llm'" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8" }, + { name = "numpy" }, { name = "numpy", marker = "extra == 'typing'" }, { name = "paper-qa", extras = ["ldp"], marker = "extra == 'paperqa'", specifier = ">=5" }, { name = "pillow", marker = "extra == 'image'" }, From dc32b544f6f82082bf4c381e3a0188a72da4d193 Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 13 Dec 2024 15:55:25 -0800 Subject: [PATCH 11/13] trigger checks From e6d18a0ff3bba4e5baacf1cfe6bb24afdd47b95a Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 13 Dec 2024 16:13:45 -0800 Subject: [PATCH 12/13] remove numpy and update type checks --- pyproject.toml | 1 - src/aviary/message.py | 6 +++--- uv.lock | 4 +--- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 74278a60..7d014f63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,6 @@ classifiers = [ dependencies = [ "docstring_parser>=0.16", # Pin for description addition "httpx", - "numpy", "pydantic~=2.0", ] description = "Gymnasium framework for training language model agents on constructive tasks" diff --git a/src/aviary/message.py b/src/aviary/message.py index 775d9cef..6a89ff8e 100644 --- a/src/aviary/message.py +++ b/src/aviary/message.py @@ -146,10 +146,10 @@ def create_message( { "type": "image_url", "image_url": { - "url": encode_image_to_base64(image) + "url": check_if_valid_base64(image) # If image is a string, assume it's already a base64 encoded image - if isinstance(image, np.ndarray) - else check_if_valid_base64(image) + if isinstance(image, str) + else encode_image_to_base64(image) }, } for image in images diff --git a/uv.lock b/uv.lock index c6cc2f25..5f2ebc80 100644 --- a/uv.lock +++ b/uv.lock @@ -825,12 +825,11 @@ wheels = [ [[package]] name = "fhaviary" -version = "0.12.1.dev10+g29bdd09.d20241213" +version = "0.12.1.dev12+gdc32b54.d20241214" source = { editable = "." } dependencies = [ { name = "docstring-parser" }, { name = "httpx" }, - { name = "numpy" }, { name = "pydantic" }, ] @@ -981,7 +980,6 @@ requires-dist = [ { name = "litellm", marker = "python_full_version >= '3.13' and extra == 'llm'", specifier = ">=1.49.1" }, { name = "litellm", marker = "python_full_version < '3.13' and extra == 'llm'" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8" }, - { name = "numpy" }, { name = "numpy", marker = "extra == 'typing'" }, { name = "paper-qa", extras = ["ldp"], marker = "extra == 'paperqa'", specifier = ">=5" }, { name = "pillow", marker = "extra == 'image'" }, From e518abf3a75bcdb7ddaaf63227fdfa3f7f4f1cd4 Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 13 Dec 2024 16:24:47 -0800 Subject: [PATCH 13/13] final removal on numpy from logic --- src/aviary/message.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/aviary/message.py b/src/aviary/message.py index 6a89ff8e..62923b9a 100644 --- a/src/aviary/message.py +++ b/src/aviary/message.py @@ -4,7 +4,6 @@ from collections.abc import Iterable from typing import TYPE_CHECKING, ClassVar, Self -import numpy as np from pydantic import BaseModel, Field, field_validator, model_validator from aviary.utils import check_if_valid_base64, encode_image_to_base64 @@ -12,6 +11,8 @@ if TYPE_CHECKING: from logging import LogRecord + import numpy as np + class Message(BaseModel): DEFAULT_ROLE: ClassVar[str] = "user" @@ -140,7 +141,7 @@ def create_message( # Assume no images, and update to images if present content: str | list[dict] | None = text if images is not None: - if isinstance(images, str | np.ndarray): + if not isinstance(images, list): images = [images] content = [ {