From 503de10b0fc2cae14e677568151f24bef2c5bf02 Mon Sep 17 00:00:00 2001 From: Carlos Fernandez Date: Tue, 23 Jul 2024 22:34:51 +0200 Subject: [PATCH] update bytestream to receive png images --- haystack/dataclasses/byte_stream.py | 9 ++++++--- test/dataclasses/test_byte_stream.py | 11 +++++++++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/haystack/dataclasses/byte_stream.py b/haystack/dataclasses/byte_stream.py index d0398c3790..52ca6e388f 100644 --- a/haystack/dataclasses/byte_stream.py +++ b/haystack/dataclasses/byte_stream.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union @dataclass @@ -65,13 +65,16 @@ def to_string(self, encoding: str = "utf-8") -> str: return self.data.decode(encoding) @classmethod - def from_base64_image(cls, image: bytes, meta: Optional[Dict[str, Any]] = None) -> "ByteStream": + def from_base64_image( + cls, image: bytes, image_format: str = "jpg", meta: Optional[Dict[str, Any]] = None + ) -> "ByteStream": """ Create a ByteStream containing a base64 image. The 'mime_type' field will be populated with 'image_base64'. :param image: The base 64 encoded image + :param image_format: The file extension of the image. :param meta: Additional metadata to be stored with the ByteStream. 
""" - return cls(data=image, mime_type="image_base64", meta=meta or {}) + return cls(data=image, mime_type=f"image_base64/{image_format}", meta=meta or {}) diff --git a/test/dataclasses/test_byte_stream.py b/test/dataclasses/test_byte_stream.py index 3717f86def..af3586d507 100644 --- a/test/dataclasses/test_byte_stream.py +++ b/test/dataclasses/test_byte_stream.py @@ -82,7 +82,14 @@ def encode_image(image_path): base64_image = encode_image(test_files_path / "images" / "apple.jpg") - b = ByteStream.from_base64_image(base64_image, {"some": "some"}) + b = ByteStream.from_base64_image(base64_image, meta={"some": "some"}) assert b.data == base64_image - assert b.mime_type == "image_base64" + assert b.mime_type == "image_base64/jpg" + assert b.meta == {"some": "some"} + + base64_image = encode_image(test_files_path / "images" / "haystack-logo.png") + + b = ByteStream.from_base64_image(base64_image, image_format="png", meta={"some": "some"}) + assert b.data == base64_image + assert b.mime_type == "image_base64/png" assert b.meta == {"some": "some"}