Skip to content

Commit

Permalink
add jpg and png support on ChatMessage
Browse files Browse the repository at this point in the history
  • Loading branch information
CarlosFerLo committed Jul 23, 2024
1 parent 503de10 commit 2a8a707
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 22 deletions.
24 changes: 16 additions & 8 deletions haystack/dataclasses/chat_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ class ContentType(str, Enum):

TEXT = "text"
IMAGE_URL = "image_url"
IMAGE_BASE64 = "image_base64"
IMAGE_BASE64_JPG = "image_base64/jpg"
IMAGE_BASE64_PNG = "image_base64/png"

@staticmethod
def valid_byte_stream_types() -> List["ContentType"]:
"""Returns a list of all the valid types of represented by a ByteStream."""
return [ContentType.IMAGE_BASE64]
return [ContentType.IMAGE_BASE64_JPG, ContentType.IMAGE_BASE64_PNG]

def is_valid_byte_stream_type(self) -> bool:
"""Returns whether the type is a valid type for a ByteStream."""
Expand Down Expand Up @@ -156,22 +157,29 @@ def to_openai_format(self) -> Dict[str, Any]:
content.append({"type": "text", "text": part})
elif type_ is ContentType.IMAGE_URL and isinstance(part, str):
content.append({"type": "image_url", "image_url": {"url": part}})
elif type_ is ContentType.IMAGE_BASE64 and isinstance(part, ByteStream):
elif type_ is ContentType.IMAGE_BASE64_JPG and isinstance(part, ByteStream):
content.append(
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{part.to_string()}"}}
)
elif type_ is ContentType.IMAGE_BASE64_PNG and isinstance(part, ByteStream):
content.append(
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{part.to_string()}"}}
)
else:
raise ValueError("The content types stored at metadata '__haystack_content_type__' was corrupted.")
else:
if types is ContentType.TEXT and isinstance(self.content, str):
content: str = self.content
elif types is ContentType.IMAGE_URL and isinstance(self.content, str):
content = [{"type": "image_url", "image_url": {"url": self.content}}]
elif types is ContentType.IMAGE_BASE64 and isinstance(self.content, ByteStream):
content = [{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{self.content.to_string()}"},
}]
elif types is ContentType.IMAGE_BASE64_JPG and isinstance(self.content, ByteStream):
content = [
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{self.content.to_string()}"}}
]
elif types is ContentType.IMAGE_BASE64_PNG and isinstance(self.content, ByteStream):
content = [
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{self.content.to_string()}"}}
]
else:
raise ValueError("The content types stored at metadata '__haystack_content_type__' was corrupted.")

Expand Down
52 changes: 38 additions & 14 deletions test/dataclasses/test_chat_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ def test_to_openai_format_with_multimodal_content():
"content": [{"type": "image_url", "image_url": {"url": ""}}],
}

message = ChatMessage.from_user(ByteStream.from_base64_image(b"image", image_format="png"))
assert message.to_openai_format() == {
"role": "user",
"content": [{"type": "image_url", "image_url": {"url": ""}}],
}

message = ChatMessage.from_assistant(
["this is text", "image_url:images.com/test.jpg", ByteStream.from_base64_image(b"IMAGE")]
)
Expand Down Expand Up @@ -123,22 +129,22 @@ def test_to_dict():
"meta": {},
}

message = ChatMessage.from_system(ByteStream(b"bytes", mime_type="image_base64"))
message = ChatMessage.from_system(ByteStream(b"bytes", mime_type="image_base64/jpg"))
assert message.to_dict() == {
"content": {"data": b"bytes", "mime_type": "image_base64", "meta": {}},
"content": {"data": b"bytes", "mime_type": "image_base64/jpg", "meta": {}},
"role": "system",
"name": None,
"meta": {},
}

message = ChatMessage.from_user(
content=["string content", "image_url:images.com/test.jpg", ByteStream(b"bytes", mime_type="image_base64")]
content=["string content", "image_url:images.com/test.jpg", ByteStream(b"bytes", mime_type="image_base64/png")]
)
assert message.to_dict() == {
"content": [
"string content",
"image_url:images.com/test.jpg",
{"data": b"bytes", "mime_type": "image_base64", "meta": {}},
{"data": b"bytes", "mime_type": "image_base64/png", "meta": {}},
],
"role": "user",
"name": None,
Expand All @@ -157,27 +163,27 @@ def test_from_dict():

assert ChatMessage.from_dict(
data={
"content": {"data": b"bytes", "mime_type": "image_base64", "meta": {}},
"content": {"data": b"bytes", "mime_type": "image_base64/png", "meta": {}},
"role": "user",
"name": None,
"meta": {},
}
) == ChatMessage(
content=ByteStream(data=b"bytes", mime_type="image_base64"), role=ChatRole.USER, name=None, meta={}
content=ByteStream(data=b"bytes", mime_type="image_base64/png"), role=ChatRole.USER, name=None, meta={}
)
assert ChatMessage.from_dict(
data={
"content": [
"string content",
"image_url:images.com/test.jpg",
{"data": b"bytes", "mime_type": "image_base64", "meta": {}},
{"data": b"bytes", "mime_type": "image_base64/jpg", "meta": {}},
],
"role": "user",
"name": None,
"meta": {},
}
) == ChatMessage(
content=["string content", "image_url:images.com/test.jpg", ByteStream(b"bytes", mime_type="image_base64")],
content=["string content", "image_url:images.com/test.jpg", ByteStream(b"bytes", mime_type="image_base64/jpg")],
role=ChatRole.USER,
name=None,
meta={},
Expand All @@ -201,18 +207,36 @@ def test_post_init_method():
assert "__haystack_content_type__" in message.meta
assert message.meta["__haystack_content_type__"] == ContentType.IMAGE_URL

message = ChatMessage.from_system(ByteStream(data=b"content", mime_type="image_base64"))
assert message.content == ByteStream(data=b"content", mime_type="image_base64")
message = ChatMessage.from_system(ByteStream(data=b"content", mime_type="image_base64/jpg"))
assert message.content == ByteStream(data=b"content", mime_type="image_base64/jpg")
assert "__haystack_content_type__" in message.meta
assert message.meta["__haystack_content_type__"] == ContentType.IMAGE_BASE64
assert message.meta["__haystack_content_type__"] == ContentType.IMAGE_BASE64_JPG

message = ChatMessage.from_user(["content", "image_url:{{url}}", ByteStream(b"content", mime_type="image_base64")])
assert message.content == ["content", "{{url}}", ByteStream(b"content", mime_type="image_base64")]
message = ChatMessage.from_system(ByteStream(data=b"content", mime_type="image_base64/png"))
assert message.content == ByteStream(data=b"content", mime_type="image_base64/png")
assert "__haystack_content_type__" in message.meta
assert message.meta["__haystack_content_type__"] == ContentType.IMAGE_BASE64_PNG

message = ChatMessage.from_user(
[
"content",
"image_url:{{url}}",
ByteStream(b"content", mime_type="image_base64/jpg"),
ByteStream(b"content", mime_type="image_base64/png"),
]
)
assert message.content == [
"content",
"{{url}}",
ByteStream(b"content", mime_type="image_base64/jpg"),
ByteStream(b"content", mime_type="image_base64/png"),
]
assert "__haystack_content_type__" in message.meta
assert message.meta["__haystack_content_type__"] == [
ContentType.TEXT,
ContentType.IMAGE_URL,
ContentType.IMAGE_BASE64,
ContentType.IMAGE_BASE64_JPG,
ContentType.IMAGE_BASE64_PNG,
]


Expand Down

0 comments on commit 2a8a707

Please sign in to comment.