Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions haystack/utils/base_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from enum import Enum
from typing import Any, Union

import pydantic

from haystack import logging
from haystack.core.errors import DeserializationError, SerializationError
from haystack.core.serialization import generate_qualified_class_name, import_class_by_name
Expand Down Expand Up @@ -61,7 +63,7 @@ class does not have a `from_dict` method.
return obj_class.from_dict(data["data"])


def _serialize_value_with_schema(payload: Any) -> dict[str, Any]: # pylint: disable=too-many-return-statements
def _serialize_value_with_schema(payload: Any) -> dict[str, Any]: # pylint: disable=too-many-return-statements # noqa: PLR0911
"""
Serializes a value into a schema-aware format suitable for storage or transmission.

Expand All @@ -81,8 +83,13 @@ def _serialize_value_with_schema(payload: Any) -> dict[str, Any]: # pylint: dis
- "serialized_data": Contains the actual data in a simplified format.

"""
# Handle pydantic
if isinstance(payload, pydantic.BaseModel):
type_name = generate_qualified_class_name(type(payload))
return {"serialization_schema": {"type": type_name}, "serialized_data": payload.model_dump()}

# Handle dictionary case - iterate through fields
if isinstance(payload, dict):
elif isinstance(payload, dict):
schema: dict[str, Any] = {}
data: dict[str, Any] = {}

Expand Down Expand Up @@ -269,6 +276,15 @@ def _deserialize_value(value: dict[str, Any]) -> Any:
if hasattr(cls, "from_dict") and callable(cls.from_dict):
return cls.from_dict(payload)

# handle pydantic models
if issubclass(cls, pydantic.BaseModel):
try:
return cls.model_validate(payload)
except Exception as e:
raise DeserializationError(
f"Failed to deserialize data '{payload}' into Pydantic model '{value_type}'"
) from e
Comment on lines +285 to +286
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's show the error caught by the method.

Suggested change
f"Failed to deserialize data '{payload}' into Pydantic model '{value_type}'"
) from e
f"Failed to deserialize data '{payload}' into Pydantic model '{value_type}'"
f"Error: {e}"
) from e

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are already re-raising from the underlying error so I believe it's best practice to not duplicate the error string. Instead the error is available to see in the traceback.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense!


# handle enum types
if issubclass(cls, Enum):
try:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
Updated our serialization and deserialization of PipelineSnapshots to work with Pydantic BaseModels.
35 changes: 35 additions & 0 deletions test/utils/test_base_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from enum import Enum

import pydantic
import pytest

from haystack.core.errors import DeserializationError, SerializationError
Expand All @@ -16,6 +17,11 @@
)


class CustomModel(pydantic.BaseModel):
    """Minimal pydantic model used as a fixture for the serialization round-trip tests."""

    # Integer field; the invalid-data test feeds it a non-numeric string to force a validation failure.
    id: int
    # Plain string field.
    name: str


class CustomEnum(Enum):
ONE = "one"
TWO = "two"
Expand Down Expand Up @@ -424,3 +430,32 @@ def test_deserialize_value_with_wrong_value():
_deserialize_value_with_schema(
{"serialization_schema": {"type": "test_base_serialization.CustomEnum"}, "serialized_data": "NOT_VALID"}
)


def test_serialize_and_deserialize_pydantic_model():
    """Round-trip a pydantic model through the schema-aware serializer and back."""
    expected_payload = {
        "serialization_schema": {"type": "test_base_serialization.CustomModel"},
        "serialized_data": {"id": 1, "name": "Test"},
    }

    # Serialization must emit the qualified class name alongside the dumped fields.
    assert _serialize_value_with_schema(CustomModel(id=1, name="Test")) == expected_payload

    # Deserialization must rebuild a CustomModel instance carrying the same field values.
    restored = _deserialize_value_with_schema(expected_payload)
    assert isinstance(restored, CustomModel)
    assert (restored.id, restored.name) == (1, "Test")


def test_deserialize_pydantic_model_with_invalid_data():
    """Check that a payload failing pydantic validation surfaces as a DeserializationError."""
    import re  # local import: only needed to escape the expected message below

    # `id` is declared as an int on CustomModel, so this string cannot be validated.
    invalid_payload = {"id": "not_an_integer", "name": "Test"}
    expected_message = (
        "Failed to deserialize data '{'id': 'not_an_integer', 'name': 'Test'}' into "
        "Pydantic model 'test_base_serialization.CustomModel'"
    )

    # `match` is treated as a regular expression, and the message contains `{`, `}` and `.`;
    # escaping makes the comparison literal instead of relying on lenient regex parsing.
    with pytest.raises(DeserializationError, match=re.escape(expected_message)):
        _deserialize_value_with_schema(
            {
                "serialization_schema": {"type": "test_base_serialization.CustomModel"},
                "serialized_data": invalid_payload,
            }
        )
Loading