From 349289fbbd7192bbfab2cd03a650678b19438e95 Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 13 Dec 2024 11:24:14 -0800 Subject: [PATCH 01/13] message update + tests --- src/aviary/message.py | 17 +++++++++++----- tests/test_messages.py | 46 +++++++++++++++++++++++++++++++++--------- 2 files changed, 48 insertions(+), 15 deletions(-) diff --git a/src/aviary/message.py b/src/aviary/message.py index fce01a14..92c3081a 100644 --- a/src/aviary/message.py +++ b/src/aviary/message.py @@ -1,6 +1,8 @@ from __future__ import annotations import json +import numpy as np + from collections.abc import Iterable from typing import TYPE_CHECKING, ClassVar, Self @@ -11,7 +13,6 @@ if TYPE_CHECKING: from logging import LogRecord - import numpy as np class Message(BaseModel): @@ -124,16 +125,22 @@ def create_message( cls, role: str = DEFAULT_ROLE, text: str | None = None, - image: np.ndarray | None = None, + images: list[np.ndarray | str] | None = None, ) -> Self: - # Assume no image, and update to image if present + # Assume no images, and update to images if present content: str | list[dict] | None = text - if image is not None: + if images is not None: content = [ { "type": "image_url", - "image_url": {"url": encode_image_to_base64(image)}, + "image_url": { + "url": encode_image_to_base64(image) + # If image is a string, assume it's already a base64 encoded image + if isinstance(image, np.ndarray) + else image + }, } + for image in images ] if text is not None: content.append({"type": "text", "text": text}) diff --git a/tests/test_messages.py b/tests/test_messages.py index 10c6af54..f3a7c1b2 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -114,18 +114,44 @@ def test_dump(self, message: Message, expected: dict) -> None: def test_image_message(self) -> None: # An RGB image of a red square - image = np.zeros((32, 32, 3), dtype=np.uint8) - image[:] = [255, 0, 0] # (255 red, 0 green, 0 blue) is maximum red in RGB - message_text = "What color is this square? Respond only with the color name." - message_with_image = Message.create_message(text=message_text, image=image) - assert message_with_image.content - specialized_content = json.loads(message_with_image.content) - assert len(specialized_content) == 2 - text_idx, image_idx = ( - (0, 1) if specialized_content[0]["type"] == "text" else (1, 0) + red_square = np.zeros((32, 32, 3), dtype=np.uint8) + red_square[:] = [255, 0, 0] # (255 red, 0 green, 0 blue) is maximum red in RGB + + # A pre-encoded base64 image (simulated) + encoded_image = "_base64_content" + + message_text = "What color are these squares? List each color." + message_with_images = Message.create_message( + text=message_text, + images=[red_square, encoded_image] ) + + assert message_with_images.content + specialized_content = json.loads(message_with_images.content) + assert len(specialized_content) == 3 # 2 images + 1 text + + # Find indices of each content type + image_indices = [] + text_idx = None + for i, content in enumerate(specialized_content): + if content["type"] == "image_url": + image_indices.append(i) + else: + text_idx = i + + assert len(image_indices) == 2 + assert text_idx is not None assert specialized_content[text_idx]["text"] == message_text - assert "image_url" in specialized_content[image_idx] + + # Check both images are properly formatted + for idx in image_indices: + assert "image_url" in specialized_content[idx] + assert "url" in specialized_content[idx]["image_url"] + # First image should be base64 encoded, second should be the raw string + if idx == image_indices[0]: + assert specialized_content[idx]["image_url"]["url"].startswith("_base64_content" - + message_text = "What color are these squares? List each color." message_with_images = Message.create_message( - text=message_text, - images=[red_square, encoded_image] + text=message_text, images=[red_square, encoded_image] ) - + assert message_with_images.content specialized_content = json.loads(message_with_images.content) assert len(specialized_content) == 3 # 2 images + 1 text - + # Find indices of each content type image_indices = [] text_idx = None @@ -138,18 +137,20 @@ def test_image_message(self) -> None: image_indices.append(i) else: text_idx = i - + assert len(image_indices) == 2 assert text_idx is not None assert specialized_content[text_idx]["text"] == message_text - + # Check both images are properly formatted for idx in image_indices: assert "image_url" in specialized_content[idx] assert "url" in specialized_content[idx]["image_url"] # First image should be base64 encoded, second should be the raw string if idx == image_indices[0]: - assert specialized_content[idx]["image_url"]["url"].startswith("_base64_content" - - message_text = "What color are these squares? List each color." - message_with_images = Message.create_message( - text=message_text, images=[red_square, encoded_image] - ) - + @pytest.mark.parametrize( + ("images", "message_text", "expected_error", "expected_content_length"), + [ + # Case 1: Invalid base64 image should raise error + ( + [ + np.zeros((32, 32, 3), dtype=np.uint8), # red square + "_base64_content", # invalid base64 + ], + "What color are these squares? List each color.", + "Invalid base64 encoded image", + None, + ), + # Case 2: Valid images should work + ( + [ + np.zeros((32, 32, 3), dtype=np.uint8), # red square + "", # valid base64 + ], + "What color are these squares? List each color.", + None, + 3, # 2 images + 1 text + ), + # Case 3: A numpy array in non-list formatshould be converted to a base64 encoded image + ( + np.zeros((32, 32, 3), dtype=np.uint8), # red square + "What color is this square?", + None, + 2, # 1 image + 1 text + ), + # Case 4: A string should be converted to a base64 encoded image + ( + "", # valid base64 + "What color is this square?", + None, + 2, # 1 image + 1 text + ), + ], + ) + def test_image_message( + self, + images: list[np.ndarray | str] | np.ndarray | str, + message_text: str, + expected_error: str | None, + expected_content_length: int | None, + ) -> None: + # Set red color for numpy array if present + for img in images: + if isinstance(img, np.ndarray): + img[:] = [255, 0, 0] # (255 red, 0 green, 0 blue) is maximum red in RGB + + if expected_error: + with pytest.raises(ValueError, match=expected_error): + Message.create_message(text=message_text, images=images) + return + + message_with_images = Message.create_message(text=message_text, images=images) assert message_with_images.content specialized_content = json.loads(message_with_images.content) - assert len(specialized_content) == 3 # 2 images + 1 text + assert len(specialized_content) == expected_content_length # Find indices of each content type image_indices = [] @@ -138,7 +182,10 @@ def test_image_message(self) -> None: else: text_idx = i - assert len(image_indices) == 2 + if isinstance(images, list): + assert len(image_indices) == len(images) + else: + assert len(image_indices) == 1 assert text_idx is not None assert specialized_content[text_idx]["text"] == message_text @@ -146,13 +193,10 @@ def test_image_message(self) -> None: for idx in image_indices: assert "image_url" in specialized_content[idx] assert "url" in specialized_content[idx]["image_url"] - # First image should be base64 encoded, second should be the raw string - if idx == image_indices[0]: - assert specialized_content[idx]["image_url"]["url"].startswith( - " diff --git a/tests/test_messages.py b/tests/test_messages.py index 33a01394..6e787238 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -1,4 +1,5 @@ import json +import pathlib import numpy as np import pytest @@ -11,6 +12,12 @@ ToolResponseMessage, ) +FIXTURES_DIR = pathlib.Path(__file__).parent / "fixtures" / "test_images" + + +def load_base64_image(filename: str) -> str: + return (FIXTURES_DIR / filename).read_text().strip() + class TestMessage: def test_roles(self) -> None: @@ -129,7 +136,7 @@ def test_dump(self, message: Message, expected: dict) -> None: ( [ np.zeros((32, 32, 3), dtype=np.uint8), # red square - "", # valid base64 + load_base64_image("sample_image1.b64"), ], "What color are these squares? List each color.", None, @@ -144,7 +151,7 @@ def test_dump(self, message: Message, expected: dict) -> None: ), # Case 4: A string should be converted to a base64 encoded image ( - "", # valid base64 + load_base64_image("sample_image1.b64"), "What color is this square?", None, 2, # 1 image + 1 text From 6ba8ec0b77266d5e676f827a711ad989a6d0acaa Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 13 Dec 2024 14:38:58 -0800 Subject: [PATCH 08/13] making ruff + pylint happy --- src/aviary/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/aviary/utils.py b/src/aviary/utils.py index e85fbaad..c4ecf38c 100644 --- a/src/aviary/utils.py +++ b/src/aviary/utils.py @@ -76,8 +76,7 @@ def check_if_valid_base64(image: str) -> str: base64.b64decode(image) except Exception as err: raise ValueError("Invalid base64 encoded image") from err - else: - return image + return image def is_coroutine_callable(obj) -> bool: From 29bdd091d14833ecbda994fd8d1df9c9fbcee140 Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 13 Dec 2024 14:54:57 -0800 Subject: [PATCH 09/13] bump versioning back up and add descriptive docstring --- .github/workflows/tests.yml | 4 ++-- src/aviary/message.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2553ee5f..b2ca22ea 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -12,7 +12,7 @@ jobs: if: github.event_name == 'pull_request' # pre-commit-ci/lite-action only runs here strategy: matrix: - python-version: [3.11, 3.12] # Our min and max supported Python versions + python-version: [3.11, 3.13.0] # Our min and max supported Python versions steps: - uses: actions/checkout@v4 with: @@ -27,7 +27,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.11, 3.12] # Our min and max supported Python versions + python-version: [3.11, 3.13.0] # Our min and max supported Python versions steps: - uses: actions/checkout@v4 with: diff --git a/src/aviary/message.py b/src/aviary/message.py index 95440898..775d9cef 100644 --- a/src/aviary/message.py +++ b/src/aviary/message.py @@ -125,6 +125,18 @@ def create_message( text: str | None = None, images: list[np.ndarray | str] | str | np.ndarray | None = None, ) -> Self: + """Create a message with optional text and images. + + Args: + role: The role of the message. + text: The text of the message. + images: The images to include in the message. This can be a single image or + a list of images. Images can be a numpy array or a base64 encoded image + string (str). + + Returns: + The created message. + """ # Assume no images, and update to images if present content: str | list[dict] | None = text if images is not None: From 07a59003064fece45d7a838e4c6fb9a3c237e70c Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 13 Dec 2024 15:30:51 -0800 Subject: [PATCH 10/13] Add numpy to dependencies --- pyproject.toml | 1 + uv.lock | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7d014f63..74278a60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ classifiers = [ dependencies = [ "docstring_parser>=0.16", # Pin for description addition "httpx", + "numpy", "pydantic~=2.0", ] description = "Gymnasium framework for training language model agents on constructive tasks" diff --git a/uv.lock b/uv.lock index 6436548e..c6cc2f25 100644 --- a/uv.lock +++ b/uv.lock @@ -177,7 +177,7 @@ wheels = [ [[package]] name = "aviary-gsm8k" -version = "0.11.1.dev3+gef1e2d5.d20241204" +version = "0.12.1.dev3+gaec552a.d20241213" source = { editable = "packages/gsm8k" } dependencies = [ { name = "datasets" }, @@ -200,7 +200,7 @@ requires-dist = [ [[package]] name = "aviary-hotpotqa" -version = "0.11.1.dev3+gef1e2d5.d20241204" +version = "0.12.1.dev3+gaec552a.d20241213" source = { editable = "packages/hotpotqa" } dependencies = [ { name = "beautifulsoup4" }, @@ -825,11 +825,12 @@ wheels = [ [[package]] name = "fhaviary" -version = "0.11.1.dev4+g1e8ea27.d20241204" +version = "0.12.1.dev10+g29bdd09.d20241213" source = { editable = "." } dependencies = [ { name = "docstring-parser" }, { name = "httpx" }, + { name = "numpy" }, { name = "pydantic" }, ] @@ -980,6 +981,7 @@ requires-dist = [ { name = "litellm", marker = "python_full_version >= '3.13' and extra == 'llm'", specifier = ">=1.49.1" }, { name = "litellm", marker = "python_full_version < '3.13' and extra == 'llm'" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8" }, + { name = "numpy" }, { name = "numpy", marker = "extra == 'typing'" }, { name = "paper-qa", extras = ["ldp"], marker = "extra == 'paperqa'", specifier = ">=5" }, { name = "pillow", marker = "extra == 'image'" }, From dc32b544f6f82082bf4c381e3a0188a72da4d193 Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 13 Dec 2024 15:55:25 -0800 Subject: [PATCH 11/13] trigger checks From e6d18a0ff3bba4e5baacf1cfe6bb24afdd47b95a Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 13 Dec 2024 16:13:45 -0800 Subject: [PATCH 12/13] remove numpy and update type checks --- pyproject.toml | 1 - src/aviary/message.py | 6 +++--- uv.lock | 4 +--- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 74278a60..7d014f63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,6 @@ classifiers = [ dependencies = [ "docstring_parser>=0.16", # Pin for description addition "httpx", - "numpy", "pydantic~=2.0", ] description = "Gymnasium framework for training language model agents on constructive tasks" diff --git a/src/aviary/message.py b/src/aviary/message.py index 775d9cef..6a89ff8e 100644 --- a/src/aviary/message.py +++ b/src/aviary/message.py @@ -146,10 +146,10 @@ def create_message( { "type": "image_url", "image_url": { - "url": encode_image_to_base64(image) + "url": check_if_valid_base64(image) # If image is a string, assume it's already a base64 encoded image - if isinstance(image, np.ndarray) - else check_if_valid_base64(image) + if isinstance(image, str) + else encode_image_to_base64(image) }, } for image in images diff --git a/uv.lock b/uv.lock index c6cc2f25..5f2ebc80 100644 --- a/uv.lock +++ b/uv.lock @@ -825,12 +825,11 @@ wheels = [ [[package]] name = "fhaviary" -version = "0.12.1.dev10+g29bdd09.d20241213" +version = "0.12.1.dev12+gdc32b54.d20241214" source = { editable = "." } dependencies = [ { name = "docstring-parser" }, { name = "httpx" }, - { name = "numpy" }, { name = "pydantic" }, ] @@ -981,7 +980,6 @@ requires-dist = [ { name = "litellm", marker = "python_full_version >= '3.13' and extra == 'llm'", specifier = ">=1.49.1" }, { name = "litellm", marker = "python_full_version < '3.13' and extra == 'llm'" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8" }, - { name = "numpy" }, { name = "numpy", marker = "extra == 'typing'" }, { name = "paper-qa", extras = ["ldp"], marker = "extra == 'paperqa'", specifier = ">=5" }, { name = "pillow", marker = "extra == 'image'" }, From e518abf3a75bcdb7ddaaf63227fdfa3f7f4f1cd4 Mon Sep 17 00:00:00 2001 From: Ludovico Mitchener Date: Fri, 13 Dec 2024 16:24:47 -0800 Subject: [PATCH 13/13] final removal on numpy from logic --- src/aviary/message.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/aviary/message.py b/src/aviary/message.py index 6a89ff8e..62923b9a 100644 --- a/src/aviary/message.py +++ b/src/aviary/message.py @@ -4,7 +4,6 @@ from collections.abc import Iterable from typing import TYPE_CHECKING, ClassVar, Self -import numpy as np from pydantic import BaseModel, Field, field_validator, model_validator from aviary.utils import check_if_valid_base64, encode_image_to_base64 @@ -12,6 +11,8 @@ if TYPE_CHECKING: from logging import LogRecord + import numpy as np + class Message(BaseModel): DEFAULT_ROLE: ClassVar[str] = "user" @@ -140,7 +141,7 @@ def create_message( # Assume no images, and update to images if present content: str | list[dict] | None = text if images is not None: - if isinstance(images, str | np.ndarray): + if not isinstance(images, list): images = [images] content = [ {