diff --git a/airflow-core/src/airflow/serialization/typing.py b/airflow-core/src/airflow/serialization/typing.py index a6169b23a78d5..35166710b7810 100644 --- a/airflow-core/src/airflow/serialization/typing.py +++ b/airflow-core/src/airflow/serialization/typing.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +from dataclasses import is_dataclass from typing import Any @@ -29,4 +30,9 @@ def is_pydantic_model(cls: Any) -> bool: """ # __pydantic_fields__ is always present on Pydantic V2 models and is a dict[str, FieldInfo] # __pydantic_validator__ is an internal validator object, always set after model build - return hasattr(cls, "__pydantic_fields__") and hasattr(cls, "__pydantic_validator__") + # Check if it is not a dataclass to prevent detecting pydantic dataclasses as pydantic models + return ( + hasattr(cls, "__pydantic_fields__") + and hasattr(cls, "__pydantic_validator__") + and not is_dataclass(cls) + ) diff --git a/airflow-core/tests/unit/serialization/serializers/test_serializers.py b/airflow-core/tests/unit/serialization/serializers/test_serializers.py index 3f753150b3804..07d36babba807 100644 --- a/airflow-core/tests/unit/serialization/serializers/test_serializers.py +++ b/airflow-core/tests/unit/serialization/serializers/test_serializers.py @@ -19,6 +19,7 @@ import datetime import decimal from importlib import metadata +from typing import ClassVar from unittest.mock import patch from zoneinfo import ZoneInfo @@ -33,10 +34,12 @@ from pendulum import DateTime from pendulum.tz.timezone import FixedTimezone, Timezone from pydantic import BaseModel, Field +from pydantic.dataclasses import dataclass as pydantic_dataclass from airflow.sdk.definitions.param import Param, ParamsDict from airflow.serialization.serde import CLASSNAME, DATA, VERSION, _stringify, decode, deserialize, serialize from airflow.serialization.serializers import builtin +from airflow.utils.module_loading import qualname from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker @@ -68,6 +71,13 @@ class FooBarModel(BaseModel): foo: str = Field() +@pydantic_dataclass +class PydanticDataclass: + __version__: ClassVar[int] = 1 + a: int + b: str + + @skip_if_force_lowest_dependencies_marker class TestSerializers: def test_datetime(self): @@ -423,6 +433,18 @@ def test_pydantic_deserialize_errors(self, klass, version, data, msg): with pytest.raises(TypeError, match=msg): deserialize(klass, version, data) + def test_pydantic_dataclass(self): + orig = PydanticDataclass(a=5, b="SerDe Pydantic Dataclass Test") + serialized = serialize(orig) + assert orig.__version__ == serialized[VERSION] + assert qualname(orig) == serialized[CLASSNAME] + assert serialized[DATA] + + decoded = deserialize(serialized) + assert decoded.a == orig.a + assert decoded.b == orig.b + assert type(decoded) is type(orig) + @pytest.mark.skipif(not PENDULUM3, reason="Test case for pendulum~=3") @pytest.mark.parametrize( "ser_value, expected",