Skip to content

Commit 512dd86

Browse files
authored
feat: Add serialization and deserialization of Enum type when creating a PipelineSnaphsot (#9869)
* refactor tests * Test refactoring and add failing test for enum * Remove redundant method * Slight refactoring * refactoring * simplification of _deserialize_value_with_schema and _deserialize_value * Add some more TODOs * Add support for enum serialization and deserialization * types * Add reno * fix linting * PR comments * Add warning message * dev comment
1 parent 18b6482 commit 512dd86

File tree

4 files changed

+324
-450
lines changed

4 files changed

+324
-450
lines changed

haystack/core/pipeline/pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,8 @@ def run( # noqa: PLR0915, PLR0912, C901, pylint: disable=too-many-branches
391391
parent_span=span,
392392
)
393393
except PipelineRuntimeError as error:
394+
# TODO Wrap creation of the pipeline snapshot with try-except in case it fails
395+
# (e.g. serialization issue)
394396
out_dir = _get_output_dir("pipeline_snapshot")
395397
break_point = Breakpoint(
396398
component_name=component_name,

haystack/utils/base_serialization.py

Lines changed: 96 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,18 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from typing import Any
5+
from enum import Enum
6+
from typing import Any, Union
67

8+
from haystack import logging
79
from haystack.core.errors import DeserializationError, SerializationError
810
from haystack.core.serialization import generate_qualified_class_name, import_class_by_name
911
from haystack.utils import deserialize_callable, serialize_callable
1012

13+
logger = logging.getLogger(__name__)
14+
15+
_PRIMITIVE_TO_SCHEMA_MAP = {type(None): "null", bool: "boolean", int: "integer", float: "number", str: "string"}
16+
1117

1218
def serialize_class_instance(obj: Any) -> dict[str, Any]:
1319
"""
@@ -55,7 +61,7 @@ class does not have a `from_dict` method.
5561
return obj_class.from_dict(data["data"])
5662

5763

58-
def _serialize_value_with_schema(payload: Any) -> dict[str, Any]:
64+
def _serialize_value_with_schema(payload: Any) -> dict[str, Any]: # pylint: disable=too-many-return-statements
5965
"""
6066
Serializes a value into a schema-aware format suitable for storage or transmission.
6167
@@ -90,10 +96,14 @@ def _serialize_value_with_schema(payload: Any) -> dict[str, Any]:
9096

9197
# Handle array case - iterate through elements
9298
elif isinstance(payload, (list, tuple, set)):
93-
# Convert to list for consistent handling
94-
pure_list = _convert_to_basic_types(list(payload))
99+
# Serialize each item in the array
100+
serialized_list = []
101+
for item in payload:
102+
serialized_value = _serialize_value_with_schema(item)
103+
serialized_list.append(serialized_value["serialized_data"])
95104

96105
# Determine item type from first element (if any)
106+
# NOTE: We do not support mixed-type lists
97107
if payload:
98108
first = next(iter(payload))
99109
item_schema = _serialize_value_with_schema(first)
@@ -108,93 +118,54 @@ def _serialize_value_with_schema(payload: Any) -> dict[str, Any]:
108118
base_schema["minItems"] = len(payload)
109119
base_schema["maxItems"] = len(payload)
110120

111-
return {"serialization_schema": base_schema, "serialized_data": pure_list}
121+
return {"serialization_schema": base_schema, "serialized_data": serialized_list}
112122

113123
# Handle Haystack style objects (e.g. dataclasses and Components)
114124
elif hasattr(payload, "to_dict") and callable(payload.to_dict):
115125
type_name = generate_qualified_class_name(type(payload))
116-
pure = _convert_to_basic_types(payload)
117126
schema = {"type": type_name}
118-
return {"serialization_schema": schema, "serialized_data": pure}
127+
return {"serialization_schema": schema, "serialized_data": payload.to_dict()}
119128

120129
# Handle callable functions serialization
121130
elif callable(payload) and not isinstance(payload, type):
122131
serialized = serialize_callable(payload)
123132
return {"serialization_schema": {"type": "typing.Callable"}, "serialized_data": serialized}
124133

134+
# Handle Enums
135+
elif isinstance(payload, Enum):
136+
type_name = generate_qualified_class_name(type(payload))
137+
return {"serialization_schema": {"type": type_name}, "serialized_data": payload.name}
138+
125139
# Handle arbitrary objects with __dict__
126140
elif hasattr(payload, "__dict__"):
127141
type_name = generate_qualified_class_name(type(payload))
128-
pure = _convert_to_basic_types(vars(payload))
129142
schema = {"type": type_name}
130-
return {"serialization_schema": schema, "serialized_data": pure}
143+
serialized_data = {}
144+
for key, value in vars(payload).items():
145+
serialized_value = _serialize_value_with_schema(value)
146+
serialized_data[key] = serialized_value["serialized_data"]
147+
return {"serialization_schema": schema, "serialized_data": serialized_data}
131148

132149
# Handle primitives
133150
else:
134-
prim_type = _primitive_schema_type(payload)
135-
schema = {"type": prim_type}
151+
schema = {"type": _primitive_schema_type(payload)}
136152
return {"serialization_schema": schema, "serialized_data": payload}
137153

138154

139155
def _primitive_schema_type(value: Any) -> str:
140156
"""
141157
Helper function to determine the schema type for primitive values.
142158
"""
143-
if value is None:
144-
return "null"
145-
if isinstance(value, bool):
146-
return "boolean"
147-
if isinstance(value, int):
148-
return "integer"
149-
if isinstance(value, float):
150-
return "number"
151-
if isinstance(value, str):
152-
return "string"
159+
for py_type, schema_value in _PRIMITIVE_TO_SCHEMA_MAP.items():
160+
if isinstance(value, py_type):
161+
return schema_value
162+
logger.warning(
163+
"Unsupported primitive type '{value_type}', falling back to 'string'", value_type=type(value).__name__
164+
)
153165
return "string" # fallback
154166

155167

156-
def _convert_to_basic_types(value: Any) -> Any:
157-
"""
158-
Helper function to recursively convert complex Python objects into their basic type equivalents.
159-
160-
This helper function traverses through nested data structures and converts all complex
161-
objects (custom classes, dataclasses, etc.) into basic Python types (dict, list, str,
162-
int, float, bool, None) that can be easily serialized.
163-
164-
The function handles:
165-
- Objects with to_dict() methods: converted using their to_dict implementation
166-
- Objects with __dict__ attribute: converted to plain dictionaries
167-
- Dictionaries: recursively converted values while preserving keys
168-
- Sequences (list, tuple, set): recursively converted while preserving type
169-
- Function objects: converted to None (functions cannot be serialized)
170-
- Primitive types: returned as-is
171-
172-
"""
173-
# dataclass‐style objects
174-
if hasattr(value, "to_dict") and callable(value.to_dict):
175-
return _convert_to_basic_types(value.to_dict())
176-
177-
# Handle function objects - they cannot be serialized, so we return None
178-
if callable(value) and not isinstance(value, type):
179-
return None
180-
181-
# arbitrary objects with __dict__
182-
if hasattr(value, "__dict__"):
183-
return {k: _convert_to_basic_types(v) for k, v in vars(value).items()}
184-
185-
# dicts
186-
if isinstance(value, dict):
187-
return {k: _convert_to_basic_types(v) for k, v in value.items()}
188-
189-
# sequences
190-
if isinstance(value, (list, tuple, set)):
191-
return [_convert_to_basic_types(v) for v in value]
192-
193-
# primitive
194-
return value
195-
196-
197-
def _deserialize_value_with_schema(serialized: dict[str, Any]) -> Any: # pylint: disable=too-many-return-statements, # noqa: PLR0911, PLR0912
168+
def _deserialize_value_with_schema(serialized: dict[str, Any]) -> Any:
198169
"""
199170
Deserializes a value with schema information back to its original form.
200171
@@ -204,6 +175,8 @@ def _deserialize_value_with_schema(serialized: dict[str, Any]) -> Any: # pylint
204175
"serialized_data": <the actual data>
205176
}
206177
178+
NOTE: For array types we only support homogeneous lists (all elements of the same type).
179+
207180
:param serialized: The serialized dict with schema and data.
208181
:returns: The deserialized value in its original form.
209182
"""
@@ -229,121 +202,83 @@ def _deserialize_value_with_schema(serialized: dict[str, Any]) -> Any: # pylint
229202

230203
# Handle object case (dictionary with properties)
231204
if schema_type == "object":
232-
properties = schema.get("properties")
233-
if properties:
234-
result: dict[str, Any] = {}
235-
236-
if isinstance(data, dict):
237-
for field, raw_value in data.items():
238-
field_schema = properties.get(field)
239-
if field_schema:
240-
# Recursively deserialize each field - avoid creating temporary dict
241-
result[field] = _deserialize_value_with_schema(
242-
{"serialization_schema": field_schema, "serialized_data": raw_value}
243-
)
244-
245-
return result
246-
else:
247-
return _deserialize_value(data)
205+
properties = schema["properties"]
206+
result: dict[str, Any] = {}
207+
for field, raw_value in data.items():
208+
field_schema = properties[field]
209+
# Recursively deserialize each field - avoid creating temporary dict
210+
result[field] = _deserialize_value_with_schema(
211+
{"serialization_schema": field_schema, "serialized_data": raw_value}
212+
)
213+
return result
248214

249215
# Handle array case
250-
elif schema_type == "array":
251-
# Cache frequently accessed schema properties
252-
item_schema = schema.get("items", {})
253-
item_type = item_schema.get("type", "any")
254-
is_set = schema.get("uniqueItems") is True
255-
is_tuple = schema.get("minItems") is not None and schema.get("maxItems") is not None
256-
257-
# Handle nested objects/arrays first (most complex case)
258-
if item_type in ("object", "array"):
259-
return [
260-
_deserialize_value_with_schema({"serialization_schema": item_schema, "serialized_data": item})
261-
for item in data
262-
]
263-
264-
# Helper function to deserialize individual items
265-
def deserialize_item(item):
266-
if item_type == "any":
267-
return _deserialize_value(item)
268-
else:
269-
return _deserialize_value({"type": item_type, "data": item})
270-
271-
# Handle different collection types
272-
if is_set:
273-
return {deserialize_item(item) for item in data}
274-
elif is_tuple:
275-
return tuple(deserialize_item(item) for item in data)
216+
if schema_type == "array":
217+
# Deserialize each item
218+
deserialized_items = [
219+
_deserialize_value_with_schema({"serialization_schema": schema["items"], "serialized_data": item})
220+
for item in data
221+
]
222+
final_array: Union[list, set, tuple]
223+
# Is a set if uniqueItems is True
224+
if schema.get("uniqueItems") is True:
225+
final_array = set(deserialized_items)
226+
# Is a tuple if minItems and maxItems are set
227+
elif schema.get("minItems") is not None and schema.get("maxItems") is not None:
228+
final_array = tuple(deserialized_items)
276229
else:
277-
return [deserialize_item(item) for item in data]
230+
# Otherwise, it's a list
231+
final_array = list(deserialized_items)
232+
return final_array
278233

279234
# Handle primitive types
280-
elif schema_type in ("null", "boolean", "integer", "number", "string"):
235+
if schema_type in _PRIMITIVE_TO_SCHEMA_MAP.values():
281236
return data
282237

283238
# Handle callable functions
284-
elif schema_type == "typing.Callable":
239+
if schema_type == "typing.Callable":
285240
return deserialize_callable(data)
286241

287242
# Handle custom class types
288-
else:
289-
return _deserialize_value({"type": schema_type, "data": data})
243+
return _deserialize_value({"type": schema_type, "data": data})
290244

291245

292-
def _deserialize_value(value: Any) -> Any: # pylint: disable=too-many-return-statements # noqa: PLR0911
246+
def _deserialize_value(value: dict[str, Any]) -> Any:
293247
"""
294248
Helper function to deserialize values from their envelope format {"type": T, "data": D}.
295249
296-
Handles four cases:
297-
- Typed envelopes: {"type": T, "data": D} where T determines deserialization method
298-
- Plain dicts: recursively deserialize values
299-
- Collections (list/tuple/set): recursively deserialize elements
300-
- Other values: return as-is
250+
This handles:
251+
- Custom classes (with a from_dict method)
252+
- Enums
253+
- Fallback for arbitrary classes (sets attributes on a blank instance)
301254
302255
:param value: The value to deserialize
303-
:returns: The deserialized value
304-
256+
:returns:
257+
The deserialized value
258+
:raises DeserializationError:
259+
If the type cannot be imported or the value is not valid for the type.
305260
"""
306261
# 1) Envelope case
307-
if isinstance(value, dict) and "type" in value and "data" in value:
308-
t = value["type"]
309-
payload = value["data"]
310-
311-
# 1.a) Array
312-
if t == "array":
313-
return [_deserialize_value(child) for child in payload]
314-
315-
# 1.b) Generic object/dict
316-
if t == "object":
317-
return {k: _deserialize_value(v) for k, v in payload.items()}
318-
319-
# 1.c) Primitive
320-
if t in ("null", "boolean", "integer", "number", "string"):
321-
return payload
322-
323-
# 1.d) Callable
324-
if t == "typing.Callable":
325-
return deserialize_callable(payload)
326-
327-
# 1.e) Custom class
328-
cls = import_class_by_name(t)
329-
# first, recursively deserialize the inner payload
330-
deserialized_payload = {k: _deserialize_value(v) for k, v in payload.items()}
331-
# try from_dict
332-
if hasattr(cls, "from_dict") and callable(cls.from_dict):
333-
return cls.from_dict(deserialized_payload)
334-
# fallback: set attributes on a blank instance
335-
instance = cls.__new__(cls)
336-
for attr_name, attr_value in deserialized_payload.items():
337-
setattr(instance, attr_name, attr_value)
338-
return instance
339-
340-
# 2) Plain dict (no envelope) → recurse
341-
if isinstance(value, dict):
342-
return {k: _deserialize_value(v) for k, v in value.items()}
343-
344-
# 3) Collections → recurse
345-
if isinstance(value, (list, tuple, set)):
346-
return type(value)(_deserialize_value(v) for v in value)
347-
348-
# 4) Fallback (shouldn't usually happen with our schema)
349-
return value
262+
value_type = value["type"]
263+
payload = value["data"]
264+
265+
# Custom class where value_type is a qualified class name
266+
cls = import_class_by_name(value_type)
267+
268+
# try from_dict (e.g. Haystack dataclasses and Components)
269+
if hasattr(cls, "from_dict") and callable(cls.from_dict):
270+
return cls.from_dict(payload)
271+
272+
# handle enum types
273+
if issubclass(cls, Enum):
274+
try:
275+
return cls[payload]
276+
except Exception as e:
277+
raise DeserializationError(f"Value '{payload}' is not a valid member of Enum '{value_type}'") from e
278+
279+
# fallback: set attributes on a blank instance
280+
deserialized_payload = {k: _deserialize_value(v) for k, v in payload.items()}
281+
instance = cls.__new__(cls)
282+
for attr_name, attr_value in deserialized_payload.items():
283+
setattr(instance, attr_name, attr_value)
284+
return instance
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
features:
3+
- |
4+
Updated our serialization and deserialization of PipelineSnapshots to work with python Enum classes.

0 commit comments

Comments
 (0)