diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index 6bcd6cf35d..cfa822038b 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -391,6 +391,8 @@ def run( # noqa: PLR0915, PLR0912, C901, pylint: disable=too-many-branches parent_span=span, ) except PipelineRuntimeError as error: + # TODO Wrap creation of the pipeline snapshot with try-except in case it fails + # (e.g. serialization issue) out_dir = _get_output_dir("pipeline_snapshot") break_point = Breakpoint( component_name=component_name, diff --git a/haystack/utils/base_serialization.py b/haystack/utils/base_serialization.py index 0f56508b06..9ab0e2e1d1 100644 --- a/haystack/utils/base_serialization.py +++ b/haystack/utils/base_serialization.py @@ -2,12 +2,18 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any +from enum import Enum +from typing import Any, Union +from haystack import logging from haystack.core.errors import DeserializationError, SerializationError from haystack.core.serialization import generate_qualified_class_name, import_class_by_name from haystack.utils import deserialize_callable, serialize_callable +logger = logging.getLogger(__name__) + +_PRIMITIVE_TO_SCHEMA_MAP = {type(None): "null", bool: "boolean", int: "integer", float: "number", str: "string"} + def serialize_class_instance(obj: Any) -> dict[str, Any]: """ @@ -55,7 +61,7 @@ class does not have a `from_dict` method. return obj_class.from_dict(data["data"]) -def _serialize_value_with_schema(payload: Any) -> dict[str, Any]: +def _serialize_value_with_schema(payload: Any) -> dict[str, Any]: # pylint: disable=too-many-return-statements """ Serializes a value into a schema-aware format suitable for storage or transmission. 
@@ -90,10 +96,14 @@ def _serialize_value_with_schema(payload: Any) -> dict[str, Any]: # Handle array case - iterate through elements elif isinstance(payload, (list, tuple, set)): - # Convert to list for consistent handling - pure_list = _convert_to_basic_types(list(payload)) + # Serialize each item in the array + serialized_list = [] + for item in payload: + serialized_value = _serialize_value_with_schema(item) + serialized_list.append(serialized_value["serialized_data"]) # Determine item type from first element (if any) + # NOTE: We do not support mixed-type lists if payload: first = next(iter(payload)) item_schema = _serialize_value_with_schema(first) @@ -108,31 +118,37 @@ def _serialize_value_with_schema(payload: Any) -> dict[str, Any]: base_schema["minItems"] = len(payload) base_schema["maxItems"] = len(payload) - return {"serialization_schema": base_schema, "serialized_data": pure_list} + return {"serialization_schema": base_schema, "serialized_data": serialized_list} # Handle Haystack style objects (e.g. 
dataclasses and Components) elif hasattr(payload, "to_dict") and callable(payload.to_dict): type_name = generate_qualified_class_name(type(payload)) - pure = _convert_to_basic_types(payload) schema = {"type": type_name} - return {"serialization_schema": schema, "serialized_data": pure} + return {"serialization_schema": schema, "serialized_data": payload.to_dict()} # Handle callable functions serialization elif callable(payload) and not isinstance(payload, type): serialized = serialize_callable(payload) return {"serialization_schema": {"type": "typing.Callable"}, "serialized_data": serialized} + # Handle Enums + elif isinstance(payload, Enum): + type_name = generate_qualified_class_name(type(payload)) + return {"serialization_schema": {"type": type_name}, "serialized_data": payload.name} + # Handle arbitrary objects with __dict__ elif hasattr(payload, "__dict__"): type_name = generate_qualified_class_name(type(payload)) - pure = _convert_to_basic_types(vars(payload)) schema = {"type": type_name} - return {"serialization_schema": schema, "serialized_data": pure} + serialized_data = {} + for key, value in vars(payload).items(): + serialized_value = _serialize_value_with_schema(value) + serialized_data[key] = serialized_value["serialized_data"] + return {"serialization_schema": schema, "serialized_data": serialized_data} # Handle primitives else: - prim_type = _primitive_schema_type(payload) - schema = {"type": prim_type} + schema = {"type": _primitive_schema_type(payload)} return {"serialization_schema": schema, "serialized_data": payload} @@ -140,61 +156,16 @@ def _primitive_schema_type(value: Any) -> str: """ Helper function to determine the schema type for primitive values. 
""" - if value is None: - return "null" - if isinstance(value, bool): - return "boolean" - if isinstance(value, int): - return "integer" - if isinstance(value, float): - return "number" - if isinstance(value, str): - return "string" + for py_type, schema_value in _PRIMITIVE_TO_SCHEMA_MAP.items(): + if isinstance(value, py_type): + return schema_value + logger.warning( + "Unsupported primitive type '{value_type}', falling back to 'string'", value_type=type(value).__name__ + ) return "string" # fallback -def _convert_to_basic_types(value: Any) -> Any: - """ - Helper function to recursively convert complex Python objects into their basic type equivalents. - - This helper function traverses through nested data structures and converts all complex - objects (custom classes, dataclasses, etc.) into basic Python types (dict, list, str, - int, float, bool, None) that can be easily serialized. - - The function handles: - - Objects with to_dict() methods: converted using their to_dict implementation - - Objects with __dict__ attribute: converted to plain dictionaries - - Dictionaries: recursively converted values while preserving keys - - Sequences (list, tuple, set): recursively converted while preserving type - - Function objects: converted to None (functions cannot be serialized) - - Primitive types: returned as-is - - """ - # dataclass‐style objects - if hasattr(value, "to_dict") and callable(value.to_dict): - return _convert_to_basic_types(value.to_dict()) - - # Handle function objects - they cannot be serialized, so we return None - if callable(value) and not isinstance(value, type): - return None - - # arbitrary objects with __dict__ - if hasattr(value, "__dict__"): - return {k: _convert_to_basic_types(v) for k, v in vars(value).items()} - - # dicts - if isinstance(value, dict): - return {k: _convert_to_basic_types(v) for k, v in value.items()} - - # sequences - if isinstance(value, (list, tuple, set)): - return [_convert_to_basic_types(v) for v in value] - - # 
primitive - return value - - -def _deserialize_value_with_schema(serialized: dict[str, Any]) -> Any: # pylint: disable=too-many-return-statements, # noqa: PLR0911, PLR0912 +def _deserialize_value_with_schema(serialized: dict[str, Any]) -> Any: """ Deserializes a value with schema information back to its original form. @@ -204,6 +175,8 @@ def _deserialize_value_with_schema(serialized: dict[str, Any]) -> Any: # pylint "serialized_data": } + NOTE: For array types we only support homogeneous lists (all elements of the same type). + :param serialized: The serialized dict with schema and data. :returns: The deserialized value in its original form. """ @@ -229,121 +202,83 @@ def _deserialize_value_with_schema(serialized: dict[str, Any]) -> Any: # pylint # Handle object case (dictionary with properties) if schema_type == "object": - properties = schema.get("properties") - if properties: - result: dict[str, Any] = {} - - if isinstance(data, dict): - for field, raw_value in data.items(): - field_schema = properties.get(field) - if field_schema: - # Recursively deserialize each field - avoid creating temporary dict - result[field] = _deserialize_value_with_schema( - {"serialization_schema": field_schema, "serialized_data": raw_value} - ) - - return result - else: - return _deserialize_value(data) + properties = schema["properties"] + result: dict[str, Any] = {} + for field, raw_value in data.items(): + field_schema = properties[field] + # Recursively deserialize each field - avoid creating temporary dict + result[field] = _deserialize_value_with_schema( + {"serialization_schema": field_schema, "serialized_data": raw_value} + ) + return result # Handle array case - elif schema_type == "array": - # Cache frequently accessed schema properties - item_schema = schema.get("items", {}) - item_type = item_schema.get("type", "any") - is_set = schema.get("uniqueItems") is True - is_tuple = schema.get("minItems") is not None and schema.get("maxItems") is not None - - # Handle nested 
objects/arrays first (most complex case) - if item_type in ("object", "array"): - return [ - _deserialize_value_with_schema({"serialization_schema": item_schema, "serialized_data": item}) - for item in data - ] - - # Helper function to deserialize individual items - def deserialize_item(item): - if item_type == "any": - return _deserialize_value(item) - else: - return _deserialize_value({"type": item_type, "data": item}) - - # Handle different collection types - if is_set: - return {deserialize_item(item) for item in data} - elif is_tuple: - return tuple(deserialize_item(item) for item in data) + if schema_type == "array": + # Deserialize each item + deserialized_items = [ + _deserialize_value_with_schema({"serialization_schema": schema["items"], "serialized_data": item}) + for item in data + ] + final_array: Union[list, set, tuple] + # Is a set if uniqueItems is True + if schema.get("uniqueItems") is True: + final_array = set(deserialized_items) + # Is a tuple if minItems and maxItems are set + elif schema.get("minItems") is not None and schema.get("maxItems") is not None: + final_array = tuple(deserialized_items) else: - return [deserialize_item(item) for item in data] + # Otherwise, it's a list + final_array = list(deserialized_items) + return final_array # Handle primitive types - elif schema_type in ("null", "boolean", "integer", "number", "string"): + if schema_type in _PRIMITIVE_TO_SCHEMA_MAP.values(): return data # Handle callable functions - elif schema_type == "typing.Callable": + if schema_type == "typing.Callable": return deserialize_callable(data) # Handle custom class types - else: - return _deserialize_value({"type": schema_type, "data": data}) + return _deserialize_value({"type": schema_type, "data": data}) -def _deserialize_value(value: Any) -> Any: # pylint: disable=too-many-return-statements # noqa: PLR0911 +def _deserialize_value(value: dict[str, Any]) -> Any: """ Helper function to deserialize values from their envelope format {"type": T, 
"data": D}. - Handles four cases: - - Typed envelopes: {"type": T, "data": D} where T determines deserialization method - - Plain dicts: recursively deserialize values - - Collections (list/tuple/set): recursively deserialize elements - - Other values: return as-is + This handles: + - Custom classes (with a from_dict method) + - Enums + - Fallback for arbitrary classes (sets attributes on a blank instance) :param value: The value to deserialize - :returns: The deserialized value - + :returns: + The deserialized value + :raises DeserializationError: + If the type cannot be imported or the value is not valid for the type. """ # 1) Envelope case - if isinstance(value, dict) and "type" in value and "data" in value: - t = value["type"] - payload = value["data"] - - # 1.a) Array - if t == "array": - return [_deserialize_value(child) for child in payload] - - # 1.b) Generic object/dict - if t == "object": - return {k: _deserialize_value(v) for k, v in payload.items()} - - # 1.c) Primitive - if t in ("null", "boolean", "integer", "number", "string"): - return payload - - # 1.d) Callable - if t == "typing.Callable": - return deserialize_callable(payload) - - # 1.e) Custom class - cls = import_class_by_name(t) - # first, recursively deserialize the inner payload - deserialized_payload = {k: _deserialize_value(v) for k, v in payload.items()} - # try from_dict - if hasattr(cls, "from_dict") and callable(cls.from_dict): - return cls.from_dict(deserialized_payload) - # fallback: set attributes on a blank instance - instance = cls.__new__(cls) - for attr_name, attr_value in deserialized_payload.items(): - setattr(instance, attr_name, attr_value) - return instance - - # 2) Plain dict (no envelope) → recurse - if isinstance(value, dict): - return {k: _deserialize_value(v) for k, v in value.items()} - - # 3) Collections → recurse - if isinstance(value, (list, tuple, set)): - return type(value)(_deserialize_value(v) for v in value) - - # 4) Fallback (shouldn't usually happen with our 
schema) - return value + value_type = value["type"] + payload = value["data"] + + # Custom class where value_type is a qualified class name + cls = import_class_by_name(value_type) + + # try from_dict (e.g. Haystack dataclasses and Components) + if hasattr(cls, "from_dict") and callable(cls.from_dict): + return cls.from_dict(payload) + + # handle enum types + if issubclass(cls, Enum): + try: + return cls[payload] + except Exception as e: + raise DeserializationError(f"Value '{payload}' is not a valid member of Enum '{value_type}'") from e + + # fallback: set attributes on a blank instance + deserialized_payload = {k: _deserialize_value(v) for k, v in payload.items()} + instance = cls.__new__(cls) + for attr_name, attr_value in deserialized_payload.items(): + setattr(instance, attr_name, attr_value) + return instance diff --git a/releasenotes/notes/add-enum-serialization-support-0ff44d00e9474e93.yaml b/releasenotes/notes/add-enum-serialization-support-0ff44d00e9474e93.yaml new file mode 100644 index 0000000000..d3e63cd65b --- /dev/null +++ b/releasenotes/notes/add-enum-serialization-support-0ff44d00e9474e93.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Updated our serialization and deserialization of PipelineSnapshots to work with Python Enum classes.
diff --git a/test/utils/test_base_serialization.py b/test/utils/test_base_serialization.py index 0217f7dd03..8bb42c0d97 100644 --- a/test/utils/test_base_serialization.py +++ b/test/utils/test_base_serialization.py @@ -2,6 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 +from enum import Enum + import pytest from haystack.core.errors import DeserializationError, SerializationError @@ -14,6 +16,11 @@ ) +class CustomEnum(Enum): + ONE = "one" + TWO = "two" + + class CustomClass: def to_dict(self): return {"key": "value", "more": False} @@ -76,177 +83,203 @@ def test_deserialize_class_instance_invalid_data(): deserialize_class_instance({"type": "test_base_serialization.CustomClassNoFromDict", "data": {}}) -def test_serialize_value_primitive_types(): - numbers = 1 - string = "test" - _bool = True - none = None - result = _serialize_value_with_schema(numbers) - assert result == {"serialization_schema": {"type": "integer"}, "serialized_data": 1} - result = _serialize_value_with_schema(string) - assert result == {"serialization_schema": {"type": "string"}, "serialized_data": "test"} - result = _serialize_value_with_schema(_bool) - assert result == {"serialization_schema": {"type": "boolean"}, "serialized_data": True} - result = _serialize_value_with_schema(none) - assert result == {"serialization_schema": {"type": "null"}, "serialized_data": None} - - -def test_deserialize_value_primitive_types(): - result = _deserialize_value_with_schema({"serialization_schema": {"type": "integer"}, "serialized_data": 1}) - assert result == 1 - result = _deserialize_value_with_schema({"serialization_schema": {"type": "string"}, "serialized_data": "test"}) - assert result == "test" - result = _deserialize_value_with_schema({"serialization_schema": {"type": "boolean"}, "serialized_data": True}) - assert result == True - result = _deserialize_value_with_schema({"serialization_schema": {"type": "null"}, "serialized_data": None}) - assert result == None - - -def 
test_serialize_value_with_sequences(): - sequences = [1, 2, 3] - set_sequences = {1, 2, 3} - tuple_sequences = (1, 2, 3) - result = _serialize_value_with_schema(sequences) - assert result == { - "serialization_schema": {"type": "array", "items": {"type": "integer"}}, - "serialized_data": [1, 2, 3], - } - result = _serialize_value_with_schema(set_sequences) - assert result == { - "serialization_schema": {"type": "array", "items": {"type": "integer"}, "uniqueItems": True}, - "serialized_data": [1, 2, 3], - } - result = _serialize_value_with_schema(tuple_sequences) - assert result == { - "serialization_schema": {"type": "array", "items": {"type": "integer"}, "minItems": 3, "maxItems": 3}, - "serialized_data": [1, 2, 3], - } - - -def test_deserialize_value_with_sequences(): - sequences = [1, 2, 3] - set_sequences = {1, 2, 3} - tuple_sequences = (1, 2, 3) - result = _deserialize_value_with_schema( - {"serialization_schema": {"type": "array", "items": {"type": "integer"}}, "serialized_data": [1, 2, 3]} - ) - assert result == sequences - result = _deserialize_value_with_schema( - { - "serialization_schema": {"type": "array", "items": {"type": "integer"}, "uniqueItems": True}, - "serialized_data": [1, 2, 3], - } - ) - assert result == set_sequences - result = _deserialize_value_with_schema( - { - "serialization_schema": { - "type": "array", - "items": {"type": "integer"}, - "collection_type": "tuple", - "minItems": 3, - "maxItems": 3, +@pytest.mark.parametrize( + "value,result", + [ + # integer + (1, {"serialization_schema": {"type": "integer"}, "serialized_data": 1}), + # float + (1.5, {"serialization_schema": {"type": "number"}, "serialized_data": 1.5}), + # string + ("test", {"serialization_schema": {"type": "string"}, "serialized_data": "test"}), + # boolean + (True, {"serialization_schema": {"type": "boolean"}, "serialized_data": True}), + (False, {"serialization_schema": {"type": "boolean"}, "serialized_data": False}), + # None + (None, {"serialization_schema": 
{"type": "null"}, "serialized_data": None}), + ], +) +def test_serialize_and_deserialize_primitive_types(value, result): + assert _serialize_value_with_schema(value) == result + assert _deserialize_value_with_schema(result) == value + + +@pytest.mark.parametrize( + "value,result", + [ + # empty dict + ({}, {"serialization_schema": {"type": "object", "properties": {}}, "serialized_data": {}}), + # empty list + ([], {"serialization_schema": {"type": "array", "items": {}}, "serialized_data": []}), + # empty tuple + ( + (), + { + "serialization_schema": {"type": "array", "items": {}, "minItems": 0, "maxItems": 0}, + "serialized_data": [], }, - "serialized_data": [1, 2, 3], - } - ) - assert result == tuple_sequences - - -def test_serializing_and_deserializing_nested_lists(): - nested_lists = [[1, 2], [3, 4]] - - serialized_nested_lists = _serialize_value_with_schema(nested_lists) - assert serialized_nested_lists == { - "serialization_schema": {"type": "array", "items": {"type": "array", "items": {"type": "integer"}}}, - "serialized_data": [[1, 2], [3, 4]], - } - - deserialized_nested_lists = _deserialize_value_with_schema(serialized_nested_lists) - assert deserialized_nested_lists == nested_lists - - -def test_serializing_and_deserializing_nested_answer_lists(): - """Test that _deserialize_value_with_schema handles nested lists""" - - nested_answers_list = [ - [ - GeneratedAnswer( - data="Paris", - query="What is the capital of France?", - documents=[Document(content="Paris is the capital of France")], - meta={"page": 1}, - ) - ], - [ - GeneratedAnswer( - data="Berlin", - query="What is the capital of Germany?", - documents=[Document(content="Berlin is the capital of Germany")], - meta={"page": 1}, - ) - ], - ] - serialized_nested_answers_list = _serialize_value_with_schema(nested_answers_list) - assert serialized_nested_answers_list == { - "serialization_schema": { - "type": "array", - "items": {"type": "array", "items": {"type": 
"haystack.dataclasses.answer.GeneratedAnswer"}}, - }, - "serialized_data": [ - [ - { - "type": "haystack.dataclasses.answer.GeneratedAnswer", - "init_parameters": { - "data": "Paris", - "query": "What is the capital of France?", - "documents": [ - { - "id": "413dccdf51a54cca75b7ed2eddac04e6e58560bd2f0caf4106a3efc023fe3651", - "content": "Paris is the capital of France", - "blob": None, - "meta": {}, - "score": None, - "embedding": None, - "sparse_embedding": None, - } - ], - "meta": {"page": 1}, + ), + # empty set + (set(), {"serialization_schema": {"type": "array", "items": {}, "uniqueItems": True}, "serialized_data": []}), + # nested empty structures + ( + {"empty_list": [], "empty_dict": {}, "nested_empty": {"empty": []}}, + { + "serialization_schema": { + "type": "object", + "properties": { + "empty_list": {"type": "array", "items": {}}, + "empty_dict": {"type": "object", "properties": {}}, + "nested_empty": {"type": "object", "properties": {"empty": {"type": "array", "items": {}}}}, }, - } - ], + }, + "serialized_data": {"empty_list": [], "empty_dict": {}, "nested_empty": {"empty": []}}, + }, + ), + ], +) +def test_serializing_and_deserializing_empty_structures(value, result): + assert _serialize_value_with_schema(value) == result + assert _deserialize_value_with_schema(result) == value + + +@pytest.mark.parametrize( + "value,result", + [ + # list + ( + [1, 2, 3], + {"serialization_schema": {"type": "array", "items": {"type": "integer"}}, "serialized_data": [1, 2, 3]}, + ), + # set + ( + {1, 2, 3}, + { + "serialization_schema": {"type": "array", "items": {"type": "integer"}, "uniqueItems": True}, + "serialized_data": [1, 2, 3], + }, + ), + # tuple + ( + (1, 2, 3), + { + "serialization_schema": {"type": "array", "items": {"type": "integer"}, "minItems": 3, "maxItems": 3}, + "serialized_data": [1, 2, 3], + }, + ), + # nested list + ( + [[1, 2], [3, 4]], + { + "serialization_schema": {"type": "array", "items": {"type": "array", "items": {"type": "integer"}}}, + 
"serialized_data": [[1, 2], [3, 4]], + }, + ), + # list of set + ( + [{1, 2}, {3, 4}], + { + "serialization_schema": { + "items": {"items": {"type": "integer"}, "type": "array", "uniqueItems": True}, + "type": "array", + }, + "serialized_data": [[1, 2], [3, 4]], + }, + ), + # nested tuple + ( + ((1, 2), (3, 4), (5, 6)), + { + "serialization_schema": { + "type": "array", + "items": {"type": "array", "items": {"type": "integer"}, "minItems": 2, "maxItems": 2}, + "minItems": 3, + "maxItems": 3, + }, + "serialized_data": [[1, 2], [3, 4], [5, 6]], + }, + ), + # nested list of GeneratedAnswer + ( [ - { - "type": "haystack.dataclasses.answer.GeneratedAnswer", - "init_parameters": { - "data": "Berlin", - "query": "What is the capital of Germany?", - "documents": [ - { - "id": "c7b5b839963fcbf9b394b24c883731e840c3170ace33afb7af87a2de8a257f6f", - "content": "Berlin is the capital of Germany", - "blob": None, - "meta": {}, - "score": None, - "embedding": None, - "sparse_embedding": None, - } - ], - "meta": {"page": 1}, - }, - } + [ + GeneratedAnswer( + data="Paris", + query="What is the capital of France?", + documents=[Document(content="Paris is the capital of France", id="1")], + meta={"page": 1}, + ) + ], + [ + GeneratedAnswer( + data="Berlin", + query="What is the capital of Germany?", + documents=[Document(content="Berlin is the capital of Germany", id="2")], + meta={"page": 1}, + ) + ], ], - ], - } - - deserialized_nested_answers_list = _deserialize_value_with_schema(serialized_nested_answers_list) - assert deserialized_nested_answers_list == nested_answers_list + { + "serialization_schema": { + "type": "array", + "items": {"type": "array", "items": {"type": "haystack.dataclasses.answer.GeneratedAnswer"}}, + }, + "serialized_data": [ + [ + { + "type": "haystack.dataclasses.answer.GeneratedAnswer", + "init_parameters": { + "data": "Paris", + "query": "What is the capital of France?", + "documents": [ + { + "id": "1", + "content": "Paris is the capital of France", + 
"blob": None, + "meta": {}, + "score": None, + "embedding": None, + "sparse_embedding": None, + } + ], + "meta": {"page": 1}, + }, + } + ], + [ + { + "type": "haystack.dataclasses.answer.GeneratedAnswer", + "init_parameters": { + "data": "Berlin", + "query": "What is the capital of Germany?", + "documents": [ + { + "id": "2", + "content": "Berlin is the capital of Germany", + "blob": None, + "meta": {}, + "score": None, + "embedding": None, + "sparse_embedding": None, + } + ], + "meta": {"page": 1}, + }, + } + ], + ], + }, + ), + ], +) +def test_serialize_and_deserialize_sequence_types(value, result): + assert _serialize_value_with_schema(value) == result + assert _deserialize_value_with_schema(result) == value -def test_serializing_and_deserializing_nested_dicts(): +def test_serialize_and_deserialize_nested_dicts(): data = {"key1": {"nested1": "value1", "nested2": {"deep": "value2"}}} - serialized_nested_dicts = _serialize_value_with_schema(data) - assert serialized_nested_dicts == { + expected = { "serialization_schema": { "type": "object", "properties": { @@ -261,63 +294,28 @@ def test_serializing_and_deserializing_nested_dicts(): }, "serialized_data": {"key1": {"nested1": "value1", "nested2": {"deep": "value2"}}}, } - - deserialized_nested_dicts = _deserialize_value_with_schema(serialized_nested_dicts) - assert deserialized_nested_dicts == data - - -def test_serializing_and_deserializing_nested_sets(): - nested_sets = [{1, 2}, {3, 4}] - - result = _serialize_value_with_schema(nested_sets) - assert result == { - "serialization_schema": { - "items": {"items": {"type": "integer"}, "type": "array", "uniqueItems": True}, - "type": "array", - }, - "serialized_data": [[1, 2], [3, 4]], - } - - result = _deserialize_value_with_schema( - { - "serialization_schema": { - "items": {"items": {"type": "integer"}, "type": "array", "uniqueItems": True}, - "type": "array", - }, - "serialized_data": [[1, 2], [3, 4]], - } - ) - assert result == nested_sets - - -def 
test_serializing_and_deserializing_empty_structures(): - """Test that _deserialize_value_with_schema handles empty structures""" - data = {"empty_list": [], "empty_dict": {}, "nested_empty": {"empty": []}} - serialized_data = _serialize_value_with_schema(data) - result = _deserialize_value_with_schema(serialized_data) - - assert result == data + assert _serialize_value_with_schema(data) == expected + assert _deserialize_value_with_schema(expected) == data -def test_serialize_value_with_schema(): +def test_serialize_and_deserialize_value_with_schema_with_various_types(): data = { "numbers": 1, "messages": [ChatMessage.from_user(text="Hello, world!"), ChatMessage.from_assistant(text="Hello, world!")], "user_id": "123", "dict_of_lists": {"numbers": [1, 2, 3]}, - "documents": [Document(content="Hello, world!")], + "documents": [Document(content="Hello, world!", id="1")], "list_of_dicts": [{"numbers": [1, 2, 3]}], "answers": [ GeneratedAnswer( data="Paris", query="What is the capital of France?", - documents=[Document(content="Paris is the capital of France")], + documents=[Document(content="Paris is the capital of France", id="2")], meta={"page": 1}, ) ], } - result = _serialize_value_with_schema(data) - assert result == { + expected = { "serialization_schema": { "type": "object", "properties": { @@ -349,7 +347,7 @@ def test_serialize_value_with_schema(): "dict_of_lists": {"numbers": [1, 2, 3]}, "documents": [ { - "id": "e0f8c9e42f5535600aee6c5224bf4478b73bcf0a1bcba6f357bf162e88ff985d", + "id": "1", "content": "Hello, world!", "blob": None, "score": None, @@ -366,7 +364,7 @@ def test_serialize_value_with_schema(): "query": "What is the capital of France?", "documents": [ { - "id": "413dccdf51a54cca75b7ed2eddac04e6e58560bd2f0caf4106a3efc023fe3651", + "id": "2", "content": "Paris is the capital of France", "blob": None, "meta": {}, @@ -381,77 +379,8 @@ def test_serialize_value_with_schema(): ], }, } - - -def test_deserialize_value_with_schema(): - serialized__data = { - 
"serialization_schema": { - "type": "object", - "properties": { - "numbers": {"type": "integer"}, - "messages": {"type": "array", "items": {"type": "haystack.dataclasses.chat_message.ChatMessage"}}, - "user_id": {"type": "string"}, - "dict_of_lists": { - "type": "object", - "properties": {"numbers": {"type": "array", "items": {"type": "integer"}}}, - }, - "documents": {"type": "array", "items": {"type": "haystack.dataclasses.document.Document"}}, - "list_of_dicts": {"type": "array", "items": {"type": "string"}}, - "answers": {"type": "array", "items": {"type": "haystack.dataclasses.answer.GeneratedAnswer"}}, - }, - }, - "serialized_data": { - "numbers": 1, - "messages": [ - {"role": "user", "meta": {}, "name": None, "content": [{"text": "Hello, world!"}]}, - {"role": "assistant", "meta": {}, "name": None, "content": [{"text": "Hello, world!"}]}, - ], - "user_id": "123", - "dict_of_lists": {"numbers": [1, 2, 3]}, - "documents": [ - { - "id": "e0f8c9e42f5535600aee6c5224bf4478b73bcf0a1bcba6f357bf162e88ff985d", - "content": "Hello, world!", - "blob": None, - "score": None, - "embedding": None, - "sparse_embedding": None, - } - ], - "list_of_dicts": [{"numbers": [1, 2, 3]}], - "answers": [ - { - "type": "haystack.dataclasses.answer.GeneratedAnswer", - "init_parameters": { - "data": "Paris", - "query": "What is the capital of France?", - "documents": [ - { - "id": "413dccdf51a54cca75b7ed2eddac04e6e58560bd2f0caf4106a3efc023fe3651", - "content": "Paris is the capital of France", - "blob": None, - "meta": {}, - "score": None, - "embedding": None, - "sparse_embedding": None, - } - ], - "meta": {"page": 1}, - }, - } - ], - }, - } - - result = _deserialize_value_with_schema(serialized__data) - assert result["numbers"] == 1 - assert isinstance(result["messages"][0], ChatMessage) - assert result["messages"][0].text == "Hello, world!" 
- assert result["user_id"] == "123" - assert result["dict_of_lists"] == {"numbers": [1, 2, 3]} - assert isinstance(result["documents"][0], Document) - assert result["documents"][0].content == "Hello, world!" - assert isinstance(result["answers"][0], GeneratedAnswer) + assert _serialize_value_with_schema(data) == expected + assert _deserialize_value_with_schema(expected) == data def test_serializing_and_deserializing_custom_class_type(): @@ -474,20 +403,24 @@ def test_serializing_and_deserializing_custom_class_type(): assert isinstance(deserialized_data["custom_type"], CustomClass) -def test_serialize_value_with_callable(): - result = _serialize_value_with_schema(simple_calc_function) - assert result == { +def test_serialize_and_deserialize_value_with_callable(): + expected = { "serialization_schema": {"type": "typing.Callable"}, "serialized_data": "test_base_serialization.simple_calc_function", } + assert _serialize_value_with_schema(simple_calc_function) == expected + assert _deserialize_value_with_schema(expected) == simple_calc_function -def test_deserialize_value_with_callable(): - serialized_data = { - "serialization_schema": {"type": "typing.Callable"}, - "serialized_data": "test_base_serialization.simple_calc_function", - } +def test_serialize_and_deserialize_value_with_enum(): + data = CustomEnum.ONE + expected = {"serialization_schema": {"type": "test_base_serialization.CustomEnum"}, "serialized_data": "ONE"} + assert _serialize_value_with_schema(data) == expected + assert _deserialize_value_with_schema(expected) == data + - result = _deserialize_value_with_schema(serialized_data) - assert result is simple_calc_function - assert result(5) == 10 +def test_deserialize_value_with_wrong_value(): + with pytest.raises(DeserializationError, match="Value 'NOT_VALID' is not a valid member of Enum"): + _deserialize_value_with_schema( + {"serialization_schema": {"type": "test_base_serialization.CustomEnum"}, "serialized_data": "NOT_VALID"} + )