From 9dcaa6b12eaa2e5515d400bd0632461db852a683 Mon Sep 17 00:00:00 2001
From: Bohan Qu
Date: Fri, 13 Dec 2024 15:45:09 +0800
Subject: [PATCH] feat: ByteStream auto mime_type detection and base64 (de)encoding

---
Reviewer notes (ignored by git am, not part of the commit message):
- Line structure was restored from a whitespace-collapsed copy of this patch; indentation and blank
  context lines are reconstructed from the Python syntax and the hunk line counts.
- byte_stream.py and test_byte_stream.py were amended during review (subtype fix, annotations, docstrings,
  one extra regression assertion), so the post-image blob ids on their "index" lines are stale.

 haystack_experimental/dataclasses/__init__.py |   7 +-
 .../dataclasses/byte_stream.py                | 163 ++++++++++++++++++
 test/dataclasses/__init__.py                  |   0
 test/dataclasses/test_byte_stream.py          |  94 +++++++++++
 4 files changed, 262 insertions(+), 2 deletions(-)
 create mode 100644 haystack_experimental/dataclasses/byte_stream.py
 create mode 100644 test/dataclasses/__init__.py
 create mode 100644 test/dataclasses/test_byte_stream.py

diff --git a/haystack_experimental/dataclasses/__init__.py b/haystack_experimental/dataclasses/__init__.py
index 78a97618..17bdeb27 100644
--- a/haystack_experimental/dataclasses/__init__.py
+++ b/haystack_experimental/dataclasses/__init__.py
@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+from haystack_experimental.dataclasses.byte_stream import ByteStream
 from haystack_experimental.dataclasses.chat_message import (
     ChatMessage,
     ChatMessageContentT,
@@ -18,12 +19,14 @@
 
 __all__ = [
     "AsyncStreamingCallbackT",
+    "ByteStream",
     "ChatMessage",
+    "ChatMessageContentT",
     "ChatRole",
+    "MediaContent",
     "StreamingCallbackT",
+    "TextContent",
     "ToolCall",
     "ToolCallResult",
-    "TextContent",
-    "ChatMessageContentT",
     "Tool",
 ]
diff --git a/haystack_experimental/dataclasses/byte_stream.py b/haystack_experimental/dataclasses/byte_stream.py
new file mode 100644
index 00000000..268b3161
--- /dev/null
+++ b/haystack_experimental/dataclasses/byte_stream.py
@@ -0,0 +1,163 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+"""
+Data classes for representing binary data in the Haystack API. The ByteStream class can be used to represent binary data
+in the API, and can be converted to and from base64 encoded strings, dictionaries, and files. This is particularly
+useful for representing media files in chat messages.
+"""
+
+import logging
+import mimetypes
+from base64 import b64encode, b64decode
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ByteStream:
+    """
+    Base data class representing a binary object in the Haystack API.
+
+    :param data: The binary payload.
+    :param meta: Additional metadata to be stored with the ByteStream.
+    :param mime_type: The mime type of the payload (for example "image/png"), or None if it is not known.
+    """
+
+    data: bytes
+    meta: Dict[str, Any] = field(default_factory=dict, hash=False)
+    mime_type: Optional[str] = field(default=None)
+
+    @property
+    def type(self) -> Optional[str]:
+        """
+        Return the type of the ByteStream. This is the first part of the mime type, or None if the mime type is not set.
+
+        :return: The type of the ByteStream.
+        """
+        if self.mime_type:
+            return self.mime_type.split("/", maxsplit=1)[0]
+        return None
+
+    @property
+    def subtype(self) -> Optional[str]:
+        """
+        Return the subtype of the ByteStream. This is the second part of the mime type,
+        or None if the mime type is not set or has no subtype part (for example the malformed mime type "text").
+
+        :return: The subtype of the ByteStream.
+        """
+        if self.mime_type:
+            # Unlike split("/")[-1], partition tells us whether a "/" was present at all, so a mime type
+            # without one (e.g. "text") is not reported as being its own subtype.
+            _, separator, subtype = self.mime_type.partition("/")
+            if separator and subtype:
+                return subtype
+        return None
+
+    def to_file(self, destination_path: Path) -> None:
+        """
+        Write the ByteStream to a file. Note: the metadata will be lost.
+
+        :param destination_path: The path to write the ByteStream to.
+        """
+        with open(destination_path, "wb") as fd:
+            fd.write(self.data)
+
+    @classmethod
+    def from_file_path(
+        cls, filepath: Path, mime_type: Optional[str] = None, meta: Optional[Dict[str, Any]] = None
+    ) -> "ByteStream":
+        """
+        Create a ByteStream from the contents read from a file.
+
+        :param filepath: A valid path to a file.
+        :param mime_type: The mime type of the file. If None, it is guessed from the file name.
+        :param meta: Additional metadata to be stored with the ByteStream.
+        :returns: A new ByteStream instance.
+        """
+        if mime_type is None:
+            mime_type = mimetypes.guess_type(filepath)[0]
+            if mime_type is None:
+                logger.warning("Could not determine mime type for file %s", filepath)
+
+        with open(filepath, "rb") as fd:
+            return cls(data=fd.read(), mime_type=mime_type, meta=meta or {})
+
+    @classmethod
+    def from_string(
+        cls, text: str, encoding: str = "utf-8", mime_type: Optional[str] = None, meta: Optional[Dict[str, Any]] = None
+    ) -> "ByteStream":
+        """
+        Create a ByteStream encoding a string.
+
+        :param text: The string to encode
+        :param encoding: The encoding used to convert the string into bytes
+        :param mime_type: The mime type of the file.
+        :param meta: Additional metadata to be stored with the ByteStream.
+        :returns: A new ByteStream instance.
+        """
+        return cls(data=text.encode(encoding), mime_type=mime_type, meta=meta or {})
+
+    def to_string(self, encoding: str = "utf-8") -> str:
+        """
+        Convert the ByteStream to a string, metadata will not be included.
+
+        :param encoding: The encoding used to convert the bytes to a string. Defaults to "utf-8".
+        :returns: The string representation of the ByteStream.
+        :raises: UnicodeDecodeError: If the ByteStream data cannot be decoded with the specified encoding.
+        """
+        return self.data.decode(encoding)
+
+    @classmethod
+    def from_base64(
+        cls,
+        base64_string: str,
+        encoding: str = "utf-8",
+        meta: Optional[Dict[str, Any]] = None,
+        mime_type: Optional[str] = None,
+    ) -> "ByteStream":
+        """
+        Create a ByteStream from a base64 encoded string.
+
+        :param base64_string: The base64 encoded string representation of the ByteStream data.
+        :param encoding: The encoding used to convert the base64 string into bytes.
+        :param meta: Additional metadata to be stored with the ByteStream.
+        :param mime_type: The mime type of the file.
+        :returns: A new ByteStream instance.
+        """
+        return cls(data=b64decode(base64_string.encode(encoding)), meta=meta or {}, mime_type=mime_type)
+
+    def to_base64(self, encoding: str = "utf-8") -> str:
+        """
+        Convert the ByteStream data to a base64 encoded string.
+
+        :param encoding: The encoding used to convert the base64 bytes into a string. Defaults to "utf-8".
+        :returns: The base64 encoded string representation of the ByteStream data.
+        """
+        return b64encode(self.data).decode(encoding)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any], encoding: str = "utf-8") -> "ByteStream":
+        """
+        Create a ByteStream from a dictionary.
+
+        :param data: The dictionary representation of the ByteStream.
+        :param encoding: The encoding used to convert the base64 string into bytes.
+        :returns: A new ByteStream instance.
+        """
+        return cls.from_base64(data["data"], encoding=encoding, meta=data.get("meta"), mime_type=data.get("mime_type"))
+
+    def to_dict(self, encoding: str = "utf-8") -> Dict[str, Any]:
+        """
+        Convert the ByteStream to a dictionary.
+
+        :param encoding: The encoding used to convert the base64 bytes into a string. Defaults to "utf-8".
+        :returns: The dictionary representation of the ByteStream, with the data base64 encoded.
+        :raises: UnicodeDecodeError: If the base64 bytes cannot be decoded with the specified encoding.
+        """
+        return {"data": self.to_base64(encoding=encoding), "meta": self.meta, "mime_type": self.mime_type}
diff --git a/test/dataclasses/__init__.py b/test/dataclasses/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/dataclasses/test_byte_stream.py b/test/dataclasses/test_byte_stream.py
new file mode 100644
index 00000000..21b5af58
--- /dev/null
+++ b/test/dataclasses/test_byte_stream.py
@@ -0,0 +1,94 @@
+import pytest
+from base64 import b64encode
+from pathlib import Path
+from unittest.mock import mock_open, patch
+
+from haystack_experimental.dataclasses.byte_stream import ByteStream
+
+@pytest.fixture
+def byte_stream():
+    test_data = b"test data"
+    test_meta = {"key": "value"}
+    test_mime = "text/plain"
+    return ByteStream(data=test_data, meta=test_meta, mime_type=test_mime)
+
+def test_init(byte_stream):
+    assert byte_stream.data == b"test data"
+    assert byte_stream.meta == {"key": "value"}
+    assert byte_stream.mime_type == "text/plain"
+
+def test_type_property(byte_stream):
+    assert byte_stream.type == "text"
+    stream_without_mime = ByteStream(data=b"test data")
+    assert stream_without_mime.type is None
+
+def test_subtype_property(byte_stream):
+    assert byte_stream.subtype == "plain"
+    stream_without_mime = ByteStream(data=b"test data")
+    assert stream_without_mime.subtype is None
+    stream_without_slash = ByteStream(data=b"test data", mime_type="text")
+    assert stream_without_slash.type == "text"
+    assert stream_without_slash.subtype is None
+
+@patch("builtins.open", new_callable=mock_open)
+def test_to_file(mock_file, byte_stream):
+    path = Path("test.txt")
+    byte_stream.to_file(path)
+    mock_file.assert_called_once_with(path, "wb")
+    mock_file().write.assert_called_once_with(b"test data")
+
+@patch("builtins.open", new_callable=mock_open, read_data=b"test data")
+def test_from_file_path(mock_file):
+    path = Path("test.txt")
+    with patch("mimetypes.guess_type", return_value=("text/plain", None)):
+        byte_stream = ByteStream.from_file_path(path)
+    assert byte_stream.data == b"test data"
+    assert byte_stream.mime_type == "text/plain"
+
+@patch("mimetypes.guess_type", return_value=(None, None))
+@patch("haystack_experimental.dataclasses.byte_stream.logger.warning")
+def test_from_file_path_unknown_mime(mock_warning, _):
+    path = Path("test.txt")
+    with patch("builtins.open", new_callable=mock_open, read_data=b"test data"):
+        byte_stream = ByteStream.from_file_path(path)
+    assert byte_stream.mime_type is None
+    mock_warning.assert_called_once()
+
+def test_from_string():
+    text = "Hello, World!"
+    byte_stream = ByteStream.from_string(text, mime_type="text/plain")
+    assert byte_stream.data == text.encode("utf-8")
+    assert byte_stream.mime_type == "text/plain"
+
+def test_to_string():
+    byte_stream = ByteStream(data=b"Hello, World!")
+    assert byte_stream.to_string() == "Hello, World!"
+
+def test_from_base64():
+    base64_string = b64encode(b"test data").decode("utf-8")
+    byte_stream = ByteStream.from_base64(base64_string, mime_type="text/plain")
+    assert byte_stream.data == b"test data"
+    assert byte_stream.mime_type == "text/plain"
+
+def test_to_base64(byte_stream):
+    expected = b64encode(b"test data").decode("utf-8")
+    assert byte_stream.to_base64() == expected
+
+def test_from_dict():
+    data = {
+        "data": b64encode(b"test data").decode("utf-8"),
+        "meta": {"key": "value"},
+        "mime_type": "text/plain",
+    }
+    byte_stream = ByteStream.from_dict(data)
+    assert byte_stream.data == b"test data"
+    assert byte_stream.meta == {"key": "value"}
+    assert byte_stream.mime_type == "text/plain"
+
+def test_to_dict(byte_stream):
+    expected = {
+        "data": b64encode(b"test data").decode("utf-8"),
+        "meta": {"key": "value"},
+        "mime_type": "text/plain",
+    }
+    assert byte_stream.to_dict() == expected