Skip to content

Commit

Permalink
Merge branch 'main' into extract-answer
Browse files Browse the repository at this point in the history
  • Loading branch information
whitead authored Dec 18, 2024
2 parents d7ef73a + ef735a5 commit c6458f9
Show file tree
Hide file tree
Showing 14 changed files with 656 additions and 445 deletions.
6 changes: 5 additions & 1 deletion .github/renovate.json5
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
prHourlyLimit: 4,
timezone: "America/Los_Angeles",
rangeStrategy: "widen",
lockFileMaintenance: { enabled: true },
lockFileMaintenance: {
enabled: true,
schedule: ["* 2 1-7 * 1"], // Work around https://github.com/renovatebot/renovate/discussions/33152
},
minimumReleaseAge: "2 weeks",
"pre-commit": { enabled: true },
packageRules: [
{
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
if: github.event_name == 'pull_request' # pre-commit-ci/lite-action only runs here
strategy:
matrix:
python-version: [3.11, 3.13] # Our min and max supported Python versions
python-version: [3.11, 3.13.0] # Our min and max supported Python versions
steps:
- uses: actions/checkout@v4
with:
Expand All @@ -27,7 +27,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.11, 3.13] # Our min and max supported Python versions
python-version: [3.11, 3.13.0] # Our min and max supported Python versions
steps:
- uses: actions/checkout@v4
with:
Expand Down
1 change: 1 addition & 0 deletions .mailmap
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ James Braza <james@futurehouse.org> <jamesbraza@gmail.com>
Michael Skarlinski <mskarlinski@futurehouse.org> mskarlin <12701035+mskarlin@users.noreply.github.com>
Ryan-Rhys Griffiths <ryan@futurehouse.org> <ryangriff123@gmail.com>
Siddharth Narayanan <sid@futurehouse.org> <sidnarayanan@users.noreply.github.com>
Ludovico Mitchener <ludo@futurehouse.org> ludomitch
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ repos:
- id: codespell
additional_dependencies: [".[toml]"]
exclude_types: [jupyter]
exclude: '.*\.b64$'
- repo: https://github.com/pappasam/toml-sort
rev: v0.24.2
hooks:
Expand Down
2 changes: 1 addition & 1 deletion .python-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.13
3.12
28 changes: 26 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,30 @@ Messages have two attributes:
msg = Message(content="Hello, world!", role="assistant")
```

The `content` is a string with a text, a JSON serializable list of `dict`s, or a null value.
A list of dicts is used to encode multi-modal content. The method `create_message` can be used to create a message with images:

```py
from PIL import Image
import numpy as np

img = Image.open("your_image.jpg")
img_array = np.array(img)

msg = Message.create_message(role="user", text="Hello, world!", images=[img_array])
```

`create_message` supports images as numpy array or base64 encoded images. In this case, `content` will be a list of dictionaries with the keys `text` and `image_url`.

```py
{
{"type": "text", "text": "Hello World!"},
{"text": "image_url", "image_url": "data:image/png;base64,{base64_image}"},
}
```

We follow the structure adopted by [OpenAI](https://platform.openai.com/docs/guides/vision?lang=node#uploading-base64-encoded-images).

For the meaning of role, see the table below.
You can change around roles as desired,
except for `tool` which has a special meaning in aviary.
Expand All @@ -73,7 +97,7 @@ except for `tool` which has a special meaning in aviary.
| user | Environment system prompt or emitted observation | HotPotQA problem to solve, or details of an internal env failure |
| tool | Result of tool run in the environment | Some number crunching program's output |

The `content` is a string that can be anything, or a null value.
`Message` is extended in `ToolRequestMessage` and `ToolResponseMessage` to include the relevant tool name and arguments.

## Environment

Expand All @@ -94,7 +118,7 @@ is a boolean value.
The easiest way to create an environment is using the functional interface, which just uses functions and decorators to define environments. First, let's define what the environment looks like by defining its `start` function:

```py
from aviary import fenv
from aviary.core import fenv


@fenv.start()
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ dev = [
"pytest>=8", # Pin to keep recent
"refurb>=2", # Pin to keep recent
"typeguard",
"vcrpy>=6", # Pin for https://github.com/kevin1024/vcrpy/issues/884
]
gsm8k = ["aviary.gsm8k"]
hotpotqa = ["aviary.hotpotqa"]
Expand Down
30 changes: 25 additions & 5 deletions src/aviary/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pydantic import BaseModel, Field, field_validator, model_validator

from aviary.utils import encode_image_to_base64
from aviary.utils import encode_image_to_base64, validate_base64_image

if TYPE_CHECKING:
from logging import LogRecord
Expand Down Expand Up @@ -124,16 +124,36 @@ def create_message(
cls,
role: str = DEFAULT_ROLE,
text: str | None = None,
image: np.ndarray | None = None,
images: list[np.ndarray | str] | str | np.ndarray | None = None,
) -> Self:
# Assume no image, and update to image if present
"""Create a message with optional text and images.
Args:
role: The role of the message.
text: The text of the message.
images: The images to include in the message. This can be a single image or
a list of images. Images can be a numpy array or a base64 encoded image
string (str).
Returns:
The created message.
"""
# Assume no images, and update to images if present
content: str | list[dict] | None = text
if image is not None:
if images is not None:
if not isinstance(images, list):
images = [images]
content = [
{
"type": "image_url",
"image_url": {"url": encode_image_to_base64(image)},
"image_url": {
"url": validate_base64_image(image)
# If image is a string, assume it's already a base64 encoded image
if isinstance(image, str)
else encode_image_to_base64(image)
},
}
for image in images
]
if text is not None:
content.append({"type": "text", "text": text})
Expand Down
18 changes: 14 additions & 4 deletions src/aviary/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
}


LLM_SCORE_EVAL_CONFIG = {
LLM_SCORE_EVAL_CONFIG = LLM_EVAL_CONFIG | {
"prompt": (
"Here is a question, the correct answer to the question, and a rubric for"
" evaluating the question. Judge the proposed answer based on the given rubric."
Expand All @@ -46,8 +46,6 @@
"\n\nRubric: {correct_answer}"
"\n\nProposed answer: {proposed_answer}"
),
"model": "gpt-4o-mini",
"temperature": 0,
"max_score": 10,
}

Expand Down Expand Up @@ -85,6 +83,17 @@ def encode_image_to_base64(img: "np.ndarray") -> str:
)


def validate_base64_image(image: str) -> str:
"""Validate if the input string is a valid base64 encoded image and if it is, return the image."""
try:
# Support for inclusion of the 
1 change: 1 addition & 0 deletions tests/fixtures/test_images/sample_png_image.b64

Large diffs are not rendered by default.

111 changes: 98 additions & 13 deletions tests/test_messages.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import pathlib

import numpy as np
import pytest
Expand All @@ -11,6 +12,12 @@
ToolResponseMessage,
)

FIXTURES_DIR = pathlib.Path(__file__).parent / "fixtures" / "test_images"


def load_base64_image(filename: str) -> str:
return (FIXTURES_DIR / filename).read_text().strip()


class TestMessage:
def test_roles(self) -> None:
Expand Down Expand Up @@ -112,20 +119,98 @@ def test_str(self, message: Message, expected: str) -> None:
def test_dump(self, message: Message, expected: dict) -> None:
assert message.model_dump(exclude_none=True) == expected

def test_image_message(self) -> None:
# An RGB image of a red square
image = np.zeros((32, 32, 3), dtype=np.uint8)
image[:] = [255, 0, 0] # (255 red, 0 green, 0 blue) is maximum red in RGB
message_text = "What color is this square? Respond only with the color name."
message_with_image = Message.create_message(text=message_text, image=image)
assert message_with_image.content
specialized_content = json.loads(message_with_image.content)
assert len(specialized_content) == 2
text_idx, image_idx = (
(0, 1) if specialized_content[0]["type"] == "text" else (1, 0)
)
@pytest.mark.parametrize(
("images", "message_text", "expected_error", "expected_content_length"),
[
# Case 1: Invalid base64 image should raise error
(
[
np.zeros((32, 32, 3), dtype=np.uint8), # red square
"_base64_content", # invalid base64
],
"What color are these squares? List each color.",
"Invalid base64 encoded image",
None,
),
# Case 2: Valid images should work
(
[
np.zeros((32, 32, 3), dtype=np.uint8), # red square
load_base64_image("sample_jpeg_image.b64"),
],
"What color are these squares? List each color.",
None,
3, # 2 images + 1 text
),
# Case 3: A numpy array in non-list formatshould be converted to a base64 encoded image
(
np.zeros((32, 32, 3), dtype=np.uint8), # red square
"What color is this square?",
None,
2, # 1 image + 1 text
),
# Case 4: A string should be converted to a base64 encoded image
(
load_base64_image("sample_jpeg_image.b64"),
"What color is this square?",
None,
2, # 1 image + 1 text
),
# Case 5: A PNG image should be converted to a base64 encoded image
(
load_base64_image("sample_png_image.b64"),
"What color is this square?",
None,
2, # 1 image + 1 text
),
],
)
def test_image_message(
self,
images: list[np.ndarray | str] | np.ndarray | str,
message_text: str,
expected_error: str | None,
expected_content_length: int | None,
) -> None:
# Set red color for numpy array if present
for img in images:
if isinstance(img, np.ndarray):
img[:] = [255, 0, 0] # (255 red, 0 green, 0 blue) is maximum red in RGB

if expected_error:
with pytest.raises(ValueError, match=expected_error):
Message.create_message(text=message_text, images=images)
return

message_with_images = Message.create_message(text=message_text, images=images)
assert message_with_images.content
specialized_content = json.loads(message_with_images.content)
assert len(specialized_content) == expected_content_length

# Find indices of each content type
image_indices = []
text_idx = None
for i, content in enumerate(specialized_content):
if content["type"] == "image_url":
image_indices.append(i)
else:
text_idx = i

if isinstance(images, list):
assert len(image_indices) == len(images)
else:
assert len(image_indices) == 1
assert text_idx is not None
assert specialized_content[text_idx]["text"] == message_text
assert "image_url" in specialized_content[image_idx]

# Check both images are properly formatted
for idx in image_indices:
assert "image_url" in specialized_content[idx]
assert "url" in specialized_content[idx]["image_url"]
# Both images should be base64 encoded
assert specialized_content[idx]["image_url"]["url"].startswith(
"data:image/"
)


class TestToolRequestMessage:
Expand Down
4 changes: 3 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
],
)
@pytest.mark.asyncio
async def test_eval_answer(proposed, correct, question, eval_mode, expected):
async def test_eval_answer(
proposed: str, correct: str, question: str | None, eval_mode: str, expected: float
) -> None:
assert await eval_answer(proposed, correct, question, eval_mode) == expected


Expand Down
Loading

0 comments on commit c6458f9

Please sign in to comment.