Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
45ab19d
improve numpy serializer test coverage
nailo2c Jun 1, 2025
c56efe2
improve pandas, timezone serializer test coverage
nailo2c Jun 1, 2025
aa7b28b
improve k8s serializer test coverage
nailo2c Jun 1, 2025
93cf850
improve serialized_objects test coverage
nailo2c Jun 2, 2025
868633b
improve serialized_objects test coverage #2
nailo2c Jun 5, 2025
a3baaa2
Merge branch 'main' into test-35127-improve_code_coverage_of_serializ…
nailo2c Jun 5, 2025
4118e3f
fix DeadlineReference import path issue
nailo2c Jun 5, 2025
ea82e6c
Merge branch 'main' into test-35127-improve_code_coverage_of_serializ…
nailo2c Jun 5, 2025
b27e469
Merge branch 'main' into test-35127-improve_code_coverage_of_serializ…
nailo2c Jun 5, 2025
e3b6da6
Merge branch 'main' into test-35127-improve_code_coverage_of_serializ…
nailo2c Jun 6, 2025
b607441
Merge branch 'main' into test-35127-improve_code_coverage_of_serializ…
nailo2c Jun 7, 2025
49ecf17
Merge branch 'main' into test-35127-improve_code_coverage_of_serializ…
nailo2c Jun 8, 2025
257c62b
import DeadlineAlert from airflow.sdk.definitions.deadline
nailo2c Jun 8, 2025
a947ec8
refactor test_serializers.py
nailo2c Jun 9, 2025
0030387
refactor test_serializers_objects.py
nailo2c Jun 10, 2025
8baaf2f
fix the 'self' not found issue
nailo2c Jun 10, 2025
8a53cd4
Merge branch 'main' into test-35127-improve_code_coverage_of_serializ…
nailo2c Jun 10, 2025
106325a
rename test functions for clearer scope
nailo2c Jun 12, 2025
488c333
Merge branch 'main' into test-35127-improve_code_coverage_of_serializ…
nailo2c Jun 12, 2025
c40cc29
Merge branch 'main' into test-35127-improve_code_coverage_of_serializ…
nailo2c Jun 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 232 additions & 4 deletions airflow-core/tests/unit/serialization/serializers/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import datetime
import decimal
import sys
from importlib import metadata
from unittest.mock import patch
from zoneinfo import ZoneInfo
Expand All @@ -27,18 +28,39 @@
import pendulum.tz
import pytest
from dateutil.tz import tzutc
from kubernetes.client import models as k8s
from packaging import version
from pendulum import DateTime
from pendulum.tz.timezone import FixedTimezone, Timezone

from airflow.sdk.definitions.param import Param, ParamsDict
from airflow.serialization.serde import DATA, deserialize, serialize
from airflow.serialization.serde import CLASSNAME, DATA, VERSION, _stringify, decode, deserialize, serialize
from airflow.serialization.serializers import builtin
from airflow.utils.module_loading import qualname

from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker

PENDULUM3 = version.parse(metadata.version("pendulum")).major == 3


class CustomTZ(datetime.tzinfo):
name = "My/Custom"

def utcoffset(self, dt: datetime.datetime | None) -> datetime.timedelta:
return datetime.timedelta(hours=2)

def dst(self, dt: datetime.datetime | None) -> datetime.timedelta | None:
return datetime.timedelta(0)

def tzname(self, dt: datetime.datetime | None) -> str | None:
return self.name


class NoNameTZ(datetime.tzinfo):
    """``tzinfo`` test double that defines only ``utcoffset`` (no ``name``, no ``tzname``)."""

    def utcoffset(self, dt):
        """Return a constant two-hour offset; ``dt`` is ignored."""
        two_hours = datetime.timedelta(hours=2)
        return two_hours


@skip_if_force_lowest_dependencies_marker
class TestSerializers:
def test_datetime(self):
Expand Down Expand Up @@ -145,8 +167,6 @@ def test_encode_decimal(self, expr, expected):
assert deserialize(serialize(decimal.Decimal(expr))) == decimal.Decimal(expected)

def test_encode_k8s_v1pod(self):
from kubernetes.client import models as k8s

pod = k8s.V1Pod(
metadata=k8s.V1ObjectMeta(
name="foo",
Expand All @@ -166,12 +186,60 @@ def test_encode_k8s_v1pod(self):
"spec": {"containers": [{"image": "bar", "name": "foo"}]},
}

def test_bignum_serialize_non_decimal(self):
    """Feed a plain ``int`` to the bignum serializer and expect the empty result tuple."""
    from airflow.serialization.serializers import bignum as bignum_serializer

    empty_result = ("", "", 0, False)
    assert bignum_serializer.serialize(12345) == empty_result

@pytest.mark.parametrize(
    ("klass", "version", "payload", "msg"),
    [
        # newer version
        ("decimal.Decimal", 999, "0", r"serialized 999 of decimal\.Decimal"),
        # wrong classname
        ("wrong.ClassName", 1, "0", r"wrong\.ClassName != .*Decimal"),
    ],
)
def test_bignum_deserialize_errors(self, klass, version, payload, msg):
    """Each bad (classname, version) combination must raise ``TypeError`` matching ``msg``."""
    from airflow.serialization.serializers import bignum as bignum_serializer

    with pytest.raises(TypeError, match=msg):
        bignum_serializer.deserialize(klass, version, payload)

def test_numpy(self):
    """Round-trip a numpy ``int16`` scalar through ``serialize`` and ``deserialize``."""
    original = np.int16(10)
    restored = deserialize(serialize(original))
    assert original == restored

def test_numpy_serializers(self):
    """Call the numpy serializer module directly on two scalars and on an array."""
    from airflow.serialization.serializers import numpy as numpy_serializer

    # NOTE(review): the expected payload is ``True`` although the input is
    # ``np.bool_(False)`` -- confirm this mirrors intended serializer behaviour.
    assert numpy_serializer.serialize(np.bool_(False)) == (True, "numpy.bool_", 1, True)

    float_scalar = np.float32(3.14)
    expected_float = (float(float_scalar), "numpy.float32", 1, True)
    assert numpy_serializer.serialize(float_scalar) == expected_float

    # An ndarray is expected to yield the empty result tuple.
    assert numpy_serializer.serialize(np.array([1, 2, 3])) == ("", "", 0, False)

@pytest.mark.parametrize(
    ("klass", "ver", "value", "msg"),
    [
        ("numpy.int32", 999, 123, r"serialized version is newer"),
        ("numpy.float32", 1, 123, r"unsupported numpy\.float32"),
    ],
)
def test_numpy_deserialize_errors(self, klass, ver, value, msg):
    """Each parametrized call to the numpy deserializer must raise ``TypeError`` matching ``msg``."""
    from airflow.serialization.serializers import numpy as numpy_serializer

    with pytest.raises(TypeError, match=msg):
        numpy_serializer.deserialize(klass, ver, value)

def test_params(self):
i = ParamsDict({"x": Param(default="value", description="there is a value", key="test")})
e = serialize(i)
Expand All @@ -186,6 +254,24 @@ def test_pandas(self):
d = deserialize(e)
assert i.equals(d)

def test_pandas_serializers(self):
    """Feed a plain ``int`` to the pandas serializer and expect the empty result tuple."""
    from airflow.serialization.serializers import pandas as pandas_serializer

    empty_result = ("", "", 0, False)
    assert pandas_serializer.serialize(123) == empty_result

@pytest.mark.parametrize(
    ("version", "data", "msg"),
    [
        # version too new
        (999, "", r"serialized 999 .* > 1"),
        # bad payload type
        (1, 123, r"wrong data type .*<class 'int'>"),
    ],
)
def test_pandas_deserialize_errors(self, version, data, msg):
    """Each bad version/payload for a DataFrame classname must raise ``TypeError`` matching ``msg``."""
    from airflow.serialization.serializers import pandas as pandas_serializer

    dataframe_classname = "pandas.core.frame.DataFrame"
    with pytest.raises(TypeError, match=msg):
        pandas_serializer.deserialize(dataframe_classname, version, data)

def test_iceberg(self):
pytest.importorskip("pyiceberg", minversion="2.0.0")
from pyiceberg.catalog import Catalog
Expand All @@ -212,7 +298,7 @@ def test_iceberg(self):
mock_load_catalog.assert_called_with("catalog", uri=uri)
mock_load_table.assert_called_with((identifier[1], identifier[2]))

def test_deltalake(selfa):
def test_deltalake(self):
deltalake = pytest.importorskip("deltalake")

with (
Expand All @@ -239,6 +325,43 @@ def test_deltalake(selfa):
assert i._storage_options == d._storage_options
assert d._storage_options is None

def test_deltalake_serialize_deserialize(self):
    """Feed a bare ``object()`` to the deltalake serializer and expect the empty result tuple."""
    from airflow.serialization.serializers import deltalake as deltalake_serializer

    empty_result = ("", "", 0, False)
    assert deltalake_serializer.serialize(object()) == empty_result

@pytest.mark.parametrize(
    ("klass", "version", "payload", "msg"),
    [
        ("deltalake.table.DeltaTable", 999, {}, r"serialized version is newer than class version"),
        ("not_a_real_class", 1, {}, r"do not know how to deserialize"),
    ],
)
def test_deltalake_deserialize_errors(self, klass, version, payload, msg):
    """Each parametrized call to the deltalake deserializer must raise ``TypeError`` matching ``msg``."""
    from airflow.serialization.serializers import deltalake as deltalake_serializer

    with pytest.raises(TypeError, match=msg):
        deltalake_serializer.deserialize(klass, version, payload)

def test_kubernetes_serializer(self, monkeypatch):
    """Expect the empty result tuple when ``serialize_pod`` raises, and for a non-pod value."""
    from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
    from airflow.serialization.serializers import kubernetes as kubernetes_serializer

    def _always_fail(o):
        # Stand-in for ``PodGenerator.serialize_pod`` that raises on every call.
        raise Exception("fail")

    empty_result = ("", "", 0, False)
    pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="foo"))
    monkeypatch.setattr(PodGenerator, "serialize_pod", _always_fail)

    assert kubernetes_serializer.serialize(pod) == empty_result
    assert kubernetes_serializer.serialize(123) == empty_result

@pytest.mark.skipif(not PENDULUM3, reason="Test case for pendulum~=3")
@pytest.mark.parametrize(
"ser_value, expected",
Expand Down Expand Up @@ -386,3 +509,108 @@ def test_pendulum_2_to_3(self, ser_value, expected):
def test_pendulum_3_to_2(self, ser_value, expected):
"""Test deserialize objects in pendulum 2 which serialised in pendulum 3."""
assert deserialize(ser_value) == expected

def test_timezone_serialize_fixed(self):
    """Serialize a zero-offset ``FixedTimezone`` and compare against the expected tuple."""
    from airflow.serialization.serializers import timezone as timezone_serializer

    expected = ("UTC", "pendulum.tz.timezone.FixedTimezone", 1, True)
    assert timezone_serializer.serialize(FixedTimezone(0)) == expected

def test_timezone_serialize_no_name(self):
    """Serialize a ``NoNameTZ`` instance and expect the empty result tuple."""
    from airflow.serialization.serializers import timezone as timezone_serializer

    empty_result = ("", "", 0, False)
    assert timezone_serializer.serialize(NoNameTZ()) == empty_result

def test_timezone_deserialize_zoneinfo(self):
    """Deserialize with the ``backports.zoneinfo.ZoneInfo`` classname and expect a ``ZoneInfo``."""
    from airflow.serialization.serializers import timezone as timezone_serializer

    zone_key = "Asia/Taipei"
    restored = timezone_serializer.deserialize("backports.zoneinfo.ZoneInfo", 1, zone_key)

    assert isinstance(restored, ZoneInfo)
    assert restored.key == zone_key

@pytest.mark.parametrize(
    "klass, version, data, msg",
    [
        ("pendulum.tz.timezone.FixedTimezone", 1, 1.23, "is not of type int or str"),
        ("pendulum.tz.timezone.FixedTimezone", 999, "UTC", "serialized 999 .* > 1"),
    ],
)
def test_timezone_deserialize_errors(self, klass, version, data, msg):
    """Each parametrized call to the timezone deserializer must raise ``TypeError`` matching ``msg``."""
    from airflow.serialization.serializers import timezone as timezone_serializer

    with pytest.raises(TypeError, match=msg):
        timezone_serializer.deserialize(klass, version, data)

@pytest.mark.parametrize(
    "tz_obj, expected",
    [
        (None, None),
        (CustomTZ(), "My/Custom"),
        (ZoneInfo("Asia/Taipei"), "Asia/Taipei"),
    ],
)
def test_timezone_get_tzinfo_name(self, tz_obj, expected):
    """Check ``_get_tzinfo_name`` for ``None``, a ``CustomTZ`` and a ``ZoneInfo``."""
    from airflow.serialization.serializers import timezone as timezone_serializer

    resolved_name = timezone_serializer._get_tzinfo_name(tz_obj)
    assert resolved_name == expected

def test_json_schema_load_dag_schema_dict(self, monkeypatch):
    """With ``pkgutil.get_data`` patched to return ``None``, loading the schema must raise."""
    from airflow.exceptions import AirflowException
    from airflow.serialization.json_schema import load_dag_schema_dict

    # Patch the ``pkgutil.get_data`` reference inside json_schema to return None.
    patch_target = "airflow.serialization.json_schema.pkgutil.get_data"
    monkeypatch.setattr(patch_target, lambda __name__, fname: None)

    with pytest.raises(AirflowException) as excinfo:
        load_dag_schema_dict()

    raised_message = str(excinfo.value)
    assert "Schema file schema.json does not exists" in raised_message

def test_builtin_deserialize_frozenset(self):
    """Deserialize a list payload under the ``frozenset`` qualname and expect a ``frozenset``."""
    frozenset_classname = qualname(frozenset)
    restored = builtin.deserialize(frozenset_classname, 1, [13, 14])

    assert isinstance(restored, frozenset)
    assert restored == frozenset({13, 14})

def test_builtin_deserialize_version_too_new(self):
    """Passing version 999 to the builtin deserializer must raise ``TypeError``."""
    tuple_classname = qualname(tuple)
    expected_message = "serialized version is newer than class version"

    with pytest.raises(TypeError, match=expected_message):
        builtin.deserialize(tuple_classname, 999, [1, 2])

@pytest.mark.parametrize(
    "func, msg",
    [
        (builtin.deserialize, r"do not know how to deserialize"),
        (builtin.stringify, r"do not know how to stringify"),
    ],
)
def test_builtin_unknown_type_errors(self, func, msg):
    """Both builtin entry points must raise ``TypeError`` for the ``builtins.list`` classname."""
    list_classname = "builtins.list"
    with pytest.raises(TypeError, match=msg):
        func(list_classname, 1, [1, 2])

def test_serde_decode_type_error(self):
    """``decode`` must raise ``ValueError`` for a payload whose classname is an ``int``."""
    malformed = {CLASSNAME: 123, VERSION: 1, DATA: {}}

    with pytest.raises(ValueError, match="cannot decode"):
        decode(malformed)

def test_serde_serialize_recursion_limit(self):
    """Calling ``serialize`` with ``depth`` one below the recursion limit must raise."""
    one_below_limit = sys.getrecursionlimit() - 1
    expected_message = "maximum recursion depth reached for serialization"

    with pytest.raises(RecursionError, match=expected_message):
        serialize(object(), depth=one_below_limit)

def test_serde_deserialize_with_type_hint_stringified(self):
    """With ``type_hint=dict`` and ``full=False`` the result is compared to a string rendering."""
    payload = {"a": 1, "b": 2, "__version__": 1}
    expected = "builtins.dict@version=0(a=1,b=2,__version__=1)"

    rendered = deserialize(payload, type_hint=dict, full=False)
    assert rendered == expected

def test_serde_deserialize_empty_classname(self):
    """``deserialize`` must raise ``TypeError`` when the classname is an empty string."""
    malformed = {CLASSNAME: "", VERSION: 1, DATA: {}}

    with pytest.raises(TypeError, match="classname cannot be empty"):
        deserialize(malformed)

@pytest.mark.parametrize(
    "value, expected",
    [
        (123, "dummy@version=1(123)"),
        ([1], "dummy@version=1([,1,])"),
    ],
)
def test_serde_stringify_primitives(self, value, expected):
    """Compare ``_stringify`` output for an int and a one-element list against fixed strings."""
    rendered = _stringify("dummy", 1, value)
    assert rendered == expected
Loading