-
Notifications
You must be signed in to change notification settings - Fork 2.5k
feat: Add serialization and deserialization of Enum type when creating a PipelineSnaphsot
#9869
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
edc7161
3a23dc1
f956038
a61227c
fed4e0f
d27a00e
f11884d
e614111
962d3ac
41f3c73
5b28bce
0590a42
32492ad
61a015d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,12 +2,18 @@ | |
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from typing import Any | ||
| from enum import Enum | ||
| from typing import Any, Union | ||
|
|
||
| from haystack import logging | ||
| from haystack.core.errors import DeserializationError, SerializationError | ||
| from haystack.core.serialization import generate_qualified_class_name, import_class_by_name | ||
| from haystack.utils import deserialize_callable, serialize_callable | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| _PRIMITIVE_TO_SCHEMA_MAP = {type(None): "null", bool: "boolean", int: "integer", float: "number", str: "string"} | ||
|
|
||
|
|
||
| def serialize_class_instance(obj: Any) -> dict[str, Any]: | ||
| """ | ||
|
|
@@ -55,7 +61,7 @@ class does not have a `from_dict` method. | |
| return obj_class.from_dict(data["data"]) | ||
|
|
||
|
|
||
| def _serialize_value_with_schema(payload: Any) -> dict[str, Any]: | ||
| def _serialize_value_with_schema(payload: Any) -> dict[str, Any]: # pylint: disable=too-many-return-statements | ||
| """ | ||
| Serializes a value into a schema-aware format suitable for storage or transmission. | ||
|
|
||
|
|
@@ -90,10 +96,14 @@ def _serialize_value_with_schema(payload: Any) -> dict[str, Any]: | |
|
|
||
| # Handle array case - iterate through elements | ||
| elif isinstance(payload, (list, tuple, set)): | ||
| # Convert to list for consistent handling | ||
| pure_list = _convert_to_basic_types(list(payload)) | ||
| # Serialize each item in the array | ||
| serialized_list = [] | ||
| for item in payload: | ||
| serialized_value = _serialize_value_with_schema(item) | ||
| serialized_list.append(serialized_value["serialized_data"]) | ||
|
|
||
| # Determine item type from first element (if any) | ||
| # NOTE: We do not support mixed-type lists | ||
| if payload: | ||
| first = next(iter(payload)) | ||
| item_schema = _serialize_value_with_schema(first) | ||
|
|
@@ -108,93 +118,54 @@ def _serialize_value_with_schema(payload: Any) -> dict[str, Any]: | |
| base_schema["minItems"] = len(payload) | ||
| base_schema["maxItems"] = len(payload) | ||
|
|
||
| return {"serialization_schema": base_schema, "serialized_data": pure_list} | ||
| return {"serialization_schema": base_schema, "serialized_data": serialized_list} | ||
|
|
||
| # Handle Haystack style objects (e.g. dataclasses and Components) | ||
| elif hasattr(payload, "to_dict") and callable(payload.to_dict): | ||
| type_name = generate_qualified_class_name(type(payload)) | ||
| pure = _convert_to_basic_types(payload) | ||
| schema = {"type": type_name} | ||
| return {"serialization_schema": schema, "serialized_data": pure} | ||
| return {"serialization_schema": schema, "serialized_data": payload.to_dict()} | ||
|
|
||
| # Handle callable functions serialization | ||
| elif callable(payload) and not isinstance(payload, type): | ||
| serialized = serialize_callable(payload) | ||
| return {"serialization_schema": {"type": "typing.Callable"}, "serialized_data": serialized} | ||
|
|
||
| # Handle Enums | ||
| elif isinstance(payload, Enum): | ||
| type_name = generate_qualified_class_name(type(payload)) | ||
| return {"serialization_schema": {"type": type_name}, "serialized_data": payload.name} | ||
sjrl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # Handle arbitrary objects with __dict__ | ||
| elif hasattr(payload, "__dict__"): | ||
| type_name = generate_qualified_class_name(type(payload)) | ||
| pure = _convert_to_basic_types(vars(payload)) | ||
| schema = {"type": type_name} | ||
| return {"serialization_schema": schema, "serialized_data": pure} | ||
| serialized_data = {} | ||
| for key, value in vars(payload).items(): | ||
| serialized_value = _serialize_value_with_schema(value) | ||
| serialized_data[key] = serialized_value["serialized_data"] | ||
| return {"serialization_schema": schema, "serialized_data": serialized_data} | ||
|
|
||
| # Handle primitives | ||
| else: | ||
| prim_type = _primitive_schema_type(payload) | ||
| schema = {"type": prim_type} | ||
| schema = {"type": _primitive_schema_type(payload)} | ||
| return {"serialization_schema": schema, "serialized_data": payload} | ||
|
|
||
|
|
||
| def _primitive_schema_type(value: Any) -> str: | ||
| """ | ||
| Helper function to determine the schema type for primitive values. | ||
| """ | ||
| if value is None: | ||
| return "null" | ||
| if isinstance(value, bool): | ||
| return "boolean" | ||
| if isinstance(value, int): | ||
| return "integer" | ||
| if isinstance(value, float): | ||
| return "number" | ||
| if isinstance(value, str): | ||
| return "string" | ||
| for py_type, schema_value in _PRIMITIVE_TO_SCHEMA_MAP.items(): | ||
| if isinstance(value, py_type): | ||
| return schema_value | ||
| logger.warning( | ||
| "Unsupported primitive type '{value_type}', falling back to 'string'", value_type=type(value).__name__ | ||
| ) | ||
| return "string" # fallback | ||
sjrl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| def _convert_to_basic_types(value: Any) -> Any: | ||
| """ | ||
| Helper function to recursively convert complex Python objects into their basic type equivalents. | ||
|
Comment on lines
-156
to
-158
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. By handling every base case in |
||
|
|
||
| This helper function traverses through nested data structures and converts all complex | ||
| objects (custom classes, dataclasses, etc.) into basic Python types (dict, list, str, | ||
| int, float, bool, None) that can be easily serialized. | ||
|
|
||
| The function handles: | ||
| - Objects with to_dict() methods: converted using their to_dict implementation | ||
| - Objects with __dict__ attribute: converted to plain dictionaries | ||
| - Dictionaries: recursively converted values while preserving keys | ||
| - Sequences (list, tuple, set): recursively converted while preserving type | ||
| - Function objects: converted to None (functions cannot be serialized) | ||
| - Primitive types: returned as-is | ||
|
|
||
| """ | ||
| # dataclass‐style objects | ||
| if hasattr(value, "to_dict") and callable(value.to_dict): | ||
| return _convert_to_basic_types(value.to_dict()) | ||
|
|
||
| # Handle function objects - they cannot be serialized, so we return None | ||
| if callable(value) and not isinstance(value, type): | ||
| return None | ||
|
|
||
| # arbitrary objects with __dict__ | ||
| if hasattr(value, "__dict__"): | ||
| return {k: _convert_to_basic_types(v) for k, v in vars(value).items()} | ||
|
|
||
| # dicts | ||
| if isinstance(value, dict): | ||
| return {k: _convert_to_basic_types(v) for k, v in value.items()} | ||
|
|
||
| # sequences | ||
| if isinstance(value, (list, tuple, set)): | ||
| return [_convert_to_basic_types(v) for v in value] | ||
|
|
||
| # primitive | ||
| return value | ||
|
|
||
|
|
||
| def _deserialize_value_with_schema(serialized: dict[str, Any]) -> Any: # pylint: disable=too-many-return-statements, # noqa: PLR0911, PLR0912 | ||
| def _deserialize_value_with_schema(serialized: dict[str, Any]) -> Any: | ||
| """ | ||
| Deserializes a value with schema information back to its original form. | ||
|
|
||
|
|
@@ -204,6 +175,8 @@ def _deserialize_value_with_schema(serialized: dict[str, Any]) -> Any: # pylint | |
| "serialized_data": <the actual data> | ||
| } | ||
|
|
||
| NOTE: For array types we only support homogeneous lists (all elements of the same type). | ||
|
|
||
| :param serialized: The serialized dict with schema and data. | ||
| :returns: The deserialized value in its original form. | ||
| """ | ||
|
|
@@ -229,121 +202,83 @@ def _deserialize_value_with_schema(serialized: dict[str, Any]) -> Any: # pylint | |
|
|
||
| # Handle object case (dictionary with properties) | ||
| if schema_type == "object": | ||
| properties = schema.get("properties") | ||
| if properties: | ||
| result: dict[str, Any] = {} | ||
|
|
||
| if isinstance(data, dict): | ||
| for field, raw_value in data.items(): | ||
| field_schema = properties.get(field) | ||
| if field_schema: | ||
| # Recursively deserialize each field - avoid creating temporary dict | ||
| result[field] = _deserialize_value_with_schema( | ||
| {"serialization_schema": field_schema, "serialized_data": raw_value} | ||
| ) | ||
|
|
||
| return result | ||
| else: | ||
| return _deserialize_value(data) | ||
| properties = schema["properties"] | ||
| result: dict[str, Any] = {} | ||
| for field, raw_value in data.items(): | ||
| field_schema = properties[field] | ||
| # Recursively deserialize each field - avoid creating temporary dict | ||
| result[field] = _deserialize_value_with_schema( | ||
| {"serialization_schema": field_schema, "serialized_data": raw_value} | ||
| ) | ||
| return result | ||
|
|
||
| # Handle array case | ||
| elif schema_type == "array": | ||
| # Cache frequently accessed schema properties | ||
| item_schema = schema.get("items", {}) | ||
| item_type = item_schema.get("type", "any") | ||
| is_set = schema.get("uniqueItems") is True | ||
| is_tuple = schema.get("minItems") is not None and schema.get("maxItems") is not None | ||
|
|
||
| # Handle nested objects/arrays first (most complex case) | ||
| if item_type in ("object", "array"): | ||
| return [ | ||
| _deserialize_value_with_schema({"serialization_schema": item_schema, "serialized_data": item}) | ||
| for item in data | ||
| ] | ||
|
|
||
| # Helper function to deserialize individual items | ||
| def deserialize_item(item): | ||
| if item_type == "any": | ||
| return _deserialize_value(item) | ||
| else: | ||
| return _deserialize_value({"type": item_type, "data": item}) | ||
|
|
||
| # Handle different collection types | ||
| if is_set: | ||
| return {deserialize_item(item) for item in data} | ||
| elif is_tuple: | ||
| return tuple(deserialize_item(item) for item in data) | ||
| if schema_type == "array": | ||
| # Deserialize each item | ||
| deserialized_items = [ | ||
| _deserialize_value_with_schema({"serialization_schema": schema["items"], "serialized_data": item}) | ||
| for item in data | ||
| ] | ||
| final_array: Union[list, set, tuple] | ||
| # Is a set if uniqueItems is True | ||
| if schema.get("uniqueItems") is True: | ||
| final_array = set(deserialized_items) | ||
| # Is a tuple if minItems and maxItems are set | ||
| elif schema.get("minItems") is not None and schema.get("maxItems") is not None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to explicitly check if
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think checking for them now is enough. This was our check from before, I'll update the comment |
||
| final_array = tuple(deserialized_items) | ||
| else: | ||
| return [deserialize_item(item) for item in data] | ||
| # Otherwise, it's a list | ||
| final_array = list(deserialized_items) | ||
| return final_array | ||
|
|
||
| # Handle primitive types | ||
| elif schema_type in ("null", "boolean", "integer", "number", "string"): | ||
| if schema_type in _PRIMITIVE_TO_SCHEMA_MAP.values(): | ||
| return data | ||
|
|
||
| # Handle callable functions | ||
| elif schema_type == "typing.Callable": | ||
| if schema_type == "typing.Callable": | ||
| return deserialize_callable(data) | ||
|
|
||
| # Handle custom class types | ||
| else: | ||
| return _deserialize_value({"type": schema_type, "data": data}) | ||
| return _deserialize_value({"type": schema_type, "data": data}) | ||
|
|
||
|
|
||
| def _deserialize_value(value: Any) -> Any: # pylint: disable=too-many-return-statements # noqa: PLR0911 | ||
| def _deserialize_value(value: dict[str, Any]) -> Any: | ||
| """ | ||
| Helper function to deserialize values from their envelope format {"type": T, "data": D}. | ||
|
|
||
| Handles four cases: | ||
| - Typed envelopes: {"type": T, "data": D} where T determines deserialization method | ||
| - Plain dicts: recursively deserialize values | ||
| - Collections (list/tuple/set): recursively deserialize elements | ||
| - Other values: return as-is | ||
| This handles: | ||
| - Custom classes (with a from_dict method) | ||
| - Enums | ||
| - Fallback for arbitrary classes (sets attributes on a blank instance) | ||
|
|
||
| :param value: The value to deserialize | ||
| :returns: The deserialized value | ||
|
|
||
| :returns: | ||
| The deserialized value | ||
| :raises DeserializationError: | ||
| If the type cannot be imported or the value is not valid for the type. | ||
| """ | ||
| # 1) Envelope case | ||
| if isinstance(value, dict) and "type" in value and "data" in value: | ||
| t = value["type"] | ||
| payload = value["data"] | ||
|
|
||
| # 1.a) Array | ||
| if t == "array": | ||
| return [_deserialize_value(child) for child in payload] | ||
|
|
||
| # 1.b) Generic object/dict | ||
| if t == "object": | ||
| return {k: _deserialize_value(v) for k, v in payload.items()} | ||
|
|
||
| # 1.c) Primitive | ||
| if t in ("null", "boolean", "integer", "number", "string"): | ||
| return payload | ||
|
|
||
| # 1.d) Callable | ||
| if t == "typing.Callable": | ||
| return deserialize_callable(payload) | ||
|
|
||
| # 1.e) Custom class | ||
| cls = import_class_by_name(t) | ||
| # first, recursively deserialize the inner payload | ||
| deserialized_payload = {k: _deserialize_value(v) for k, v in payload.items()} | ||
| # try from_dict | ||
| if hasattr(cls, "from_dict") and callable(cls.from_dict): | ||
| return cls.from_dict(deserialized_payload) | ||
| # fallback: set attributes on a blank instance | ||
| instance = cls.__new__(cls) | ||
| for attr_name, attr_value in deserialized_payload.items(): | ||
| setattr(instance, attr_name, attr_value) | ||
| return instance | ||
|
|
||
| # 2) Plain dict (no envelope) → recurse | ||
| if isinstance(value, dict): | ||
| return {k: _deserialize_value(v) for k, v in value.items()} | ||
|
|
||
| # 3) Collections → recurse | ||
| if isinstance(value, (list, tuple, set)): | ||
| return type(value)(_deserialize_value(v) for v in value) | ||
|
|
||
| # 4) Fallback (shouldn't usually happen with our schema) | ||
| return value | ||
| value_type = value["type"] | ||
| payload = value["data"] | ||
|
|
||
| # Custom class where value_type is a qualified class name | ||
| cls = import_class_by_name(value_type) | ||
|
|
||
| # try from_dict (e.g. Haystack dataclasses and Components) | ||
| if hasattr(cls, "from_dict") and callable(cls.from_dict): | ||
| return cls.from_dict(payload) | ||
|
|
||
| # handle enum types | ||
| if issubclass(cls, Enum): | ||
| try: | ||
| return cls[payload] | ||
| except Exception as e: | ||
| raise DeserializationError(f"Value '{payload}' is not a valid member of Enum '{value_type}'") from e | ||
sjrl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # fallback: set attributes on a blank instance | ||
| deserialized_payload = {k: _deserialize_value(v) for k, v in payload.items()} | ||
| instance = cls.__new__(cls) | ||
| for attr_name, attr_value in deserialized_payload.items(): | ||
| setattr(instance, attr_name, attr_value) | ||
| return instance | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| --- | ||
| features: | ||
| - | | ||
| Updated our serialization and deserialization of PipelineSnapshots to work with python Enum classes. |
Uh oh!
There was an error while loading. Please reload this page.