Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion airflow-core/src/airflow/serialization/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

from dataclasses import is_dataclass
from typing import Any


Expand All @@ -29,4 +30,9 @@ def is_pydantic_model(cls: Any) -> bool:
"""
# __pydantic_fields__ is always present on Pydantic V2 models and is a dict[str, FieldInfo]
# __pydantic_validator__ is an internal validator object, always set after model build
return hasattr(cls, "__pydantic_fields__") and hasattr(cls, "__pydantic_validator__")
# Check if it is not a dataclass to prevent detecting pydantic dataclasses as pydantic models
return (
hasattr(cls, "__pydantic_fields__")
and hasattr(cls, "__pydantic_validator__")
and not is_dataclass(cls)
)
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import datetime
import decimal
from importlib import metadata
from typing import ClassVar
from unittest.mock import patch
from zoneinfo import ZoneInfo

Expand All @@ -33,10 +34,12 @@
from pendulum import DateTime
from pendulum.tz.timezone import FixedTimezone, Timezone
from pydantic import BaseModel, Field
from pydantic.dataclasses import dataclass as pydantic_dataclass

from airflow.sdk.definitions.param import Param, ParamsDict
from airflow.serialization.serde import CLASSNAME, DATA, VERSION, _stringify, decode, deserialize, serialize
from airflow.serialization.serializers import builtin
from airflow.utils.module_loading import qualname

from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker

Expand Down Expand Up @@ -68,6 +71,13 @@ class FooBarModel(BaseModel):
foo: str = Field()


@pydantic_dataclass
class PydanticDataclass:
__version__: ClassVar[int] = 1
a: int
b: str


@skip_if_force_lowest_dependencies_marker
class TestSerializers:
def test_datetime(self):
Expand Down Expand Up @@ -423,6 +433,18 @@ def test_pydantic_deserialize_errors(self, klass, version, data, msg):
with pytest.raises(TypeError, match=msg):
deserialize(klass, version, data)

def test_pydantic_dataclass(self):
orig = PydanticDataclass(a=5, b="SerDe Pydantic Dataclass Test")
serialized = serialize(orig)
assert orig.__version__ == serialized[VERSION]
assert qualname(orig) == serialized[CLASSNAME]
assert serialized[DATA]

decoded = deserialize(serialized)
assert decoded.a == orig.a
assert decoded.b == orig.b
assert type(decoded) is type(orig)

@pytest.mark.skipif(not PENDULUM3, reason="Test case for pendulum~=3")
@pytest.mark.parametrize(
"ser_value, expected",
Expand Down