Skip to content

Commit

Permalink
update bytestream to recieve png images
Browse files Browse the repository at this point in the history
  • Loading branch information
CarlosFerLo committed Jul 23, 2024
1 parent 0a8184b commit 503de10
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
9 changes: 6 additions & 3 deletions haystack/dataclasses/byte_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union


@dataclass
Expand Down Expand Up @@ -65,13 +65,16 @@ def to_string(self, encoding: str = "utf-8") -> str:
return self.data.decode(encoding)

@classmethod
def from_base64_image(cls, image: bytes, meta: Optional[Dict[str, Any]] = None) -> "ByteStream":
def from_base64_image(
cls, image: bytes, image_format: str = "jpg", meta: Optional[Dict[str, Any]] = None
) -> "ByteStream":
"""
Create a ByteStream containing a base64 image.
The 'mime_type' field will be populated with 'image_base64'.
:param image: The base 64 encoded image
:param image_format: The file extension of the image.
:param meta: Additional metadata to be stored with the ByteStream.
"""
return cls(data=image, mime_type="image_base64", meta=meta or {})
return cls(data=image, mime_type=f"image_base64/{image_format}", meta=meta or {})
11 changes: 9 additions & 2 deletions test/dataclasses/test_byte_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,14 @@ def encode_image(image_path):

base64_image = encode_image(test_files_path / "images" / "apple.jpg")

b = ByteStream.from_base64_image(base64_image, {"some": "some"})
b = ByteStream.from_base64_image(base64_image, meta={"some": "some"})
assert b.data == base64_image
assert b.mime_type == "image_base64"
assert b.mime_type == "image_base64/jpg"
assert b.meta == {"some": "some"}

base64_image = encode_image(test_files_path / "images" / "haystack-logo.png")

b = ByteStream.from_base64_image(base64_image, image_format="png", meta={"some": "some"})
assert b.data == base64_image
assert b.mime_type == "image_base64/png"
assert b.meta == {"some": "some"}

0 comments on commit 503de10

Please sign in to comment.