diff --git a/airflow-core/tests/unit/serialization/serializers/test_serializers.py b/airflow-core/tests/unit/serialization/serializers/test_serializers.py index 5b8d95f35388e..58cf0ba3cd4ee 100644 --- a/airflow-core/tests/unit/serialization/serializers/test_serializers.py +++ b/airflow-core/tests/unit/serialization/serializers/test_serializers.py @@ -18,6 +18,7 @@ import datetime import decimal +import sys from importlib import metadata from unittest.mock import patch from zoneinfo import ZoneInfo @@ -27,18 +28,39 @@ import pendulum.tz import pytest from dateutil.tz import tzutc +from kubernetes.client import models as k8s from packaging import version from pendulum import DateTime from pendulum.tz.timezone import FixedTimezone, Timezone from airflow.sdk.definitions.param import Param, ParamsDict -from airflow.serialization.serde import DATA, deserialize, serialize +from airflow.serialization.serde import CLASSNAME, DATA, VERSION, _stringify, decode, deserialize, serialize +from airflow.serialization.serializers import builtin +from airflow.utils.module_loading import qualname from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker PENDULUM3 = version.parse(metadata.version("pendulum")).major == 3 +class CustomTZ(datetime.tzinfo): + name = "My/Custom" + + def utcoffset(self, dt: datetime.datetime | None) -> datetime.timedelta: + return datetime.timedelta(hours=2) + + def dst(self, dt: datetime.datetime | None) -> datetime.timedelta | None: + return datetime.timedelta(0) + + def tzname(self, dt: datetime.datetime | None) -> str | None: + return self.name + + +class NoNameTZ(datetime.tzinfo): + def utcoffset(self, dt): + return datetime.timedelta(hours=2) + + @skip_if_force_lowest_dependencies_marker class TestSerializers: def test_datetime(self): @@ -145,8 +167,6 @@ def test_encode_decimal(self, expr, expected): assert deserialize(serialize(decimal.Decimal(expr))) == decimal.Decimal(expected) def test_encode_k8s_v1pod(self): - from kubernetes.client import models as k8s - pod = k8s.V1Pod( metadata=k8s.V1ObjectMeta( name="foo", @@ -166,12 +186,60 @@ def test_encode_k8s_v1pod(self): "spec": {"containers": [{"image": "bar", "name": "foo"}]}, } + def test_bignum_serialize_non_decimal(self): + from airflow.serialization.serializers.bignum import serialize + + assert serialize(12345) == ("", "", 0, False) + + @pytest.mark.parametrize( + ("klass", "version", "payload", "msg"), + [ + ( + "decimal.Decimal", + 999, + "0", + r"serialized 999 of decimal\.Decimal", # newer version + ), + ( + "wrong.ClassName", + 1, + "0", + r"wrong\.ClassName != .*Decimal", # wrong classname + ), + ], + ) + def test_bignum_deserialize_errors(self, klass, version, payload, msg): + from airflow.serialization.serializers.bignum import deserialize + + with pytest.raises(TypeError, match=msg): + deserialize(klass, version, payload) + def test_numpy(self): i = np.int16(10) e = serialize(i) d = deserialize(e) assert i == d + def test_numpy_serializers(self): + from airflow.serialization.serializers.numpy import serialize + + assert serialize(np.bool_(False)) == (True, "numpy.bool_", 1, True) + assert serialize(np.float32(3.14)) == (float(np.float32(3.14)), "numpy.float32", 1, True) + assert serialize(np.array([1, 2, 3])) == ("", "", 0, False) + + @pytest.mark.parametrize( + ("klass", "ver", "value", "msg"), + [ + ("numpy.int32", 999, 123, r"serialized version is newer"), + ("numpy.float32", 1, 123, r"unsupported numpy\.float32"), + ], + ) + def test_numpy_deserialize_errors(self, klass, ver, value, msg): + from airflow.serialization.serializers.numpy import deserialize + + with pytest.raises(TypeError, match=msg): + deserialize(klass, ver, value) + def test_params(self): i = ParamsDict({"x": Param(default="value", description="there is a value", key="test")}) e = serialize(i) @@ -186,6 +254,24 @@ def test_pandas(self): d = deserialize(e) assert i.equals(d) + def test_pandas_serializers(self): + from airflow.serialization.serializers.pandas import serialize + + assert serialize(123) == ("", "", 0, False) + + @pytest.mark.parametrize( + ("version", "data", "msg"), + [ + (999, "", r"serialized 999 .* > 1"), # version too new + (1, 123, r"wrong data type .*"), # bad payload type + ], + ) + def test_pandas_deserialize_errors(self, version, data, msg): + from airflow.serialization.serializers.pandas import deserialize + + with pytest.raises(TypeError, match=msg): + deserialize("pandas.core.frame.DataFrame", version, data) + def test_iceberg(self): pytest.importorskip("pyiceberg", minversion="2.0.0") from pyiceberg.catalog import Catalog @@ -212,7 +298,7 @@ def test_iceberg(self): mock_load_catalog.assert_called_with("catalog", uri=uri) mock_load_table.assert_called_with((identifier[1], identifier[2])) - def test_deltalake(selfa): + def test_deltalake(self): deltalake = pytest.importorskip("deltalake") with ( @@ -239,6 +325,43 @@ def test_deltalake(selfa): assert i._storage_options == d._storage_options assert d._storage_options is None + def test_deltalake_serialize_deserialize(self): + from airflow.serialization.serializers.deltalake import serialize + + assert serialize(object()) == ("", "", 0, False) + + @pytest.mark.parametrize( + ("klass", "version", "payload", "msg"), + [ + ( + "deltalake.table.DeltaTable", + 999, + {}, + r"serialized version is newer than class version", + ), + ( + "not_a_real_class", + 1, + {}, + r"do not know how to deserialize", + ), + ], + ) + def test_deltalake_deserialize_errors(self, klass, version, payload, msg): + from airflow.serialization.serializers.deltalake import deserialize + + with pytest.raises(TypeError, match=msg): + deserialize(klass, version, payload) + + def test_kubernetes_serializer(self, monkeypatch): + from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator + from airflow.serialization.serializers.kubernetes import serialize + + pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="foo")) + monkeypatch.setattr(PodGenerator, "serialize_pod", lambda o: (_ for _ in ()).throw(Exception("fail"))) + assert serialize(pod) == ("", "", 0, False) + assert serialize(123) == ("", "", 0, False) + @pytest.mark.skipif(not PENDULUM3, reason="Test case for pendulum~=3") @pytest.mark.parametrize( "ser_value, expected", @@ -386,3 +509,108 @@ def test_pendulum_2_to_3(self, ser_value, expected): def test_pendulum_3_to_2(self, ser_value, expected): """Test deserialize objects in pendulum 2 which serialised in pendulum 3.""" assert deserialize(ser_value) == expected + + def test_timezone_serialize_fixed(self): + from airflow.serialization.serializers.timezone import serialize + + assert serialize(FixedTimezone(0)) == ("UTC", "pendulum.tz.timezone.FixedTimezone", 1, True) + + def test_timezone_serialize_no_name(self): + from airflow.serialization.serializers.timezone import serialize + + assert serialize(NoNameTZ()) == ("", "", 0, False) + + def test_timezone_deserialize_zoneinfo(self): + from airflow.serialization.serializers.timezone import deserialize + + zi = deserialize("backports.zoneinfo.ZoneInfo", 1, "Asia/Taipei") + assert isinstance(zi, ZoneInfo) + assert zi.key == "Asia/Taipei" + + @pytest.mark.parametrize( + "klass, version, data, msg", + [ + ("pendulum.tz.timezone.FixedTimezone", 1, 1.23, "is not of type int or str"), + ("pendulum.tz.timezone.FixedTimezone", 999, "UTC", "serialized 999 .* > 1"), + ], + ) + def test_timezone_deserialize_errors(self, klass, version, data, msg): + from airflow.serialization.serializers.timezone import deserialize + + with pytest.raises(TypeError, match=msg): + deserialize(klass, version, data) + + @pytest.mark.parametrize( + "tz_obj, expected", + [ + (None, None), + (CustomTZ(), "My/Custom"), + (ZoneInfo("Asia/Taipei"), "Asia/Taipei"), + ], + ) + def test_timezone_get_tzinfo_name(self, tz_obj, expected): + from airflow.serialization.serializers.timezone import _get_tzinfo_name + + assert _get_tzinfo_name(tz_obj) == expected + + def test_json_schema_load_dag_schema_dict(self, monkeypatch): + from airflow.exceptions import AirflowException + from airflow.serialization.json_schema import load_dag_schema_dict + + monkeypatch.setattr( + "airflow.serialization.json_schema.pkgutil.get_data", lambda __name__, fname: None + ) + + with pytest.raises(AirflowException) as ctx: + load_dag_schema_dict() + assert "Schema file schema.json does not exists" in str(ctx.value) + + def test_builtin_deserialize_frozenset(self): + res = builtin.deserialize(qualname(frozenset), 1, [13, 14]) + assert isinstance(res, frozenset) + assert res == frozenset({13, 14}) + + def test_builtin_deserialize_version_too_new(self): + with pytest.raises(TypeError, match="serialized version is newer than class version"): + builtin.deserialize(qualname(tuple), 999, [1, 2]) + + @pytest.mark.parametrize( + "func, msg", + [ + (builtin.deserialize, r"do not know how to deserialize"), + (builtin.stringify, r"do not know how to stringify"), + ], + ) + def test_builtin_unknown_type_errors(self, func, msg): + with pytest.raises(TypeError, match=msg): + func("builtins.list", 1, [1, 2]) + + def test_serde_decode_type_error(self): + bad = {CLASSNAME: 123, VERSION: 1, DATA: {}} + with pytest.raises(ValueError, match="cannot decode"): + decode(bad) + + def test_serde_serialize_recursion_limit(self): + depth = sys.getrecursionlimit() - 1 + with pytest.raises(RecursionError, match="maximum recursion depth reached for serialization"): + serialize(object(), depth=depth) + + def test_serde_deserialize_with_type_hint_stringified(self): + fake = {"a": 1, "b": 2, "__version__": 1} + got = deserialize(fake, type_hint=dict, full=False) + assert got == "builtins.dict@version=0(a=1,b=2,__version__=1)" + + def test_serde_deserialize_empty_classname(self): + bad = {CLASSNAME: "", VERSION: 1, DATA: {}} + with pytest.raises(TypeError, match="classname cannot be empty"): + deserialize(bad) + + @pytest.mark.parametrize( + "value, expected", + [ + (123, "dummy@version=1(123)"), + ([1], "dummy@version=1([,1,])"), + ], + ) + def test_serde_stringify_primitives(self, value, expected): + assert _stringify("dummy", 1, value) == expected diff --git a/airflow-core/tests/unit/serialization/test_serialized_objects.py b/airflow-core/tests/unit/serialization/test_serialized_objects.py index 396bd71a784e8..84b0c02901516 100644 --- a/airflow-core/tests/unit/serialization/test_serialized_objects.py +++ b/airflow-core/tests/unit/serialization/test_serialized_objects.py @@ -25,8 +25,9 @@ import pytest from dateutil import relativedelta from kubernetes.client import models as k8s -from pendulum.tz.timezone import Timezone +from pendulum.tz.timezone import FixedTimezone, Timezone +from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest from airflow.exceptions import ( AirflowException, AirflowFailException, @@ -43,13 +44,23 @@ from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator from airflow.providers.standard.triggers.file import FileDeleteTrigger -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetUniqueKey, AssetWatcher +from airflow.sdk.definitions.asset import ( + Asset, + AssetAlias, + AssetAliasEvent, + AssetAll, + AssetAny, + AssetRef, + AssetUniqueKey, + AssetWatcher, +) from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineAlertFields, DeadlineReference from airflow.sdk.definitions.decorators import task from airflow.sdk.definitions.param import Param from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.serialized_objects import BaseSerialization, LazyDeserializedDAG, SerializedDAG +from airflow.timetables.base import DataInterval from airflow.triggers.base import BaseTrigger from airflow.utils import timezone from airflow.utils.db import LazySelectSequence @@ -135,6 +146,30 @@ class Test: BaseSerialization.serialize(obj, strict=True) # now raises +def test_validate_schema(): + from airflow.serialization.serialized_objects import BaseSerialization + + with pytest.raises(AirflowException, match="BaseSerialization is not set"): + BaseSerialization.validate_schema({"any": "thing"}) + + BaseSerialization._json_schema = object() + with pytest.raises(TypeError, match="Invalid type: Only dict and str are supported"): + BaseSerialization.validate_schema(123) + + +def test_serde_validate_schema_valid_json(): + from airflow.serialization.serialized_objects import BaseSerialization + + class Test: + def validate(self, obj): + self.obj = obj + + t = Test() + BaseSerialization._json_schema = t + BaseSerialization.validate_schema('{"foo": "bar"}') + assert t.obj == {"foo": "bar"} + + TI = TaskInstance( task=EmptyOperator(task_id="test-task"), run_id="fake_run", @@ -297,6 +332,37 @@ def __len__(self) -> int: DAT.CONNECTION, lambda a, b: a.get_uri() == b.get_uri(), ), + ( + TaskCallbackRequest( + filepath="filepath", + ti=TI, + bundle_name="testing", + bundle_version=None, + ), + DAT.TASK_CALLBACK_REQUEST, + lambda a, b: a.ti == b.ti, + ), + ( + DagCallbackRequest( + filepath="filepath", + dag_id="fake_dag", + run_id="fake_run", + bundle_name="testing", + bundle_version=None, + ), + DAT.DAG_CALLBACK_REQUEST, + lambda a, b: a.dag_id == b.dag_id, + ), + (Asset.ref(name="test"), DAT.ASSET_REF, lambda a, b: a.name == b.name), + ( + DeadlineAlert( + reference=DeadlineReference.DAGRUN_LOGICAL_DATE, + interval=timedelta(), + callback="fake_callable", + ), + None, + None, + ), ( create_outlet_event_accessors( Asset(uri="test", name="test", group="test-group"), {"key": "value"}, [] @@ -563,3 +629,125 @@ def test_get_task_assets(): ("c", asset1), ("d", asset1), ] + + +def test_lazy_dag_run_interval_wrong_dag(): + lazy = LazyDeserializedDAG(data={"dag": {"dag_id": "dag1"}}) + + with pytest.raises(ValueError, match="different DAGs"): + lazy.get_run_data_interval(DAG_RUN) + + +def test_lazy_dag_run_interval_missing_interval(): + lazy = LazyDeserializedDAG(data={"dag": {"dag_id": "test_dag_id"}}) + + with pytest.raises(ValueError, match="Cannot calculate data interval"): + lazy.get_run_data_interval(DAG_RUN) + + +def test_lazy_dag_run_interval_success(): + run = DAG_RUN + run.data_interval_start = datetime(2025, 1, 1) + run.data_interval_end = datetime(2025, 1, 2) + + lazy = LazyDeserializedDAG(data={"dag": {"dag_id": "test_dag_id"}}) + interval = lazy.get_run_data_interval(run) + + assert isinstance(interval, DataInterval) + + +def test_hash_property(): + from airflow.models.serialized_dag import SerializedDagModel + + data = {"dag": {"dag_id": "dag1"}} + lazy_serialized_dag = LazyDeserializedDAG(data=data) + assert lazy_serialized_dag.hash == SerializedDagModel.hash(data) + + +@pytest.mark.parametrize( + "payload, expected_cls", + [ + pytest.param( + { + "__type": DAT.ASSET, + "name": "test_asset", + "uri": "test://asset-uri", + "group": "test-group", + "extra": {}, + }, + Asset, + id="asset", + ), + pytest.param( + { + "__type": DAT.ASSET_ALL, + "objects": [ + { + "__type": DAT.ASSET, + "name": "x", + "uri": "test://x", + "group": "g", + "extra": {}, + }, + { + "__type": DAT.ASSET, + "name": "x", + "uri": "test://x", + "group": "g", + "extra": {}, + }, + ], + }, + AssetAll, + id="asset_all", + ), + pytest.param( + { + "__type": DAT.ASSET_ANY, + "objects": [ + { + "__type": DAT.ASSET, + "name": "y", + "uri": "test://y", + "group": "g", + "extra": {}, + } + ], + }, + AssetAny, + id="asset_any", + ), + pytest.param( + {"__type": DAT.ASSET_ALIAS, "name": "alias", "group": "g"}, + AssetAlias, + id="asset_alias", + ), + pytest.param( + {"__type": DAT.ASSET_REF, "name": "ref"}, + AssetRef, + id="asset_ref", + ), + ], +) +def test_serde_decode_asset_condition_success(payload, expected_cls): + from airflow.serialization.serialized_objects import decode_asset_condition + + assert isinstance(decode_asset_condition(payload), expected_cls) + + +def test_serde_decode_asset_condition_unknown_type(): + from airflow.serialization.serialized_objects import decode_asset_condition + + with pytest.raises( + ValueError, + match="deserialization not implemented for DAT 'UNKNOWN_TYPE'", + ): + decode_asset_condition({"__type": "UNKNOWN_TYPE"}) + + +def test_encode_timezone(): + from airflow.serialization.serialized_objects import encode_timezone + + assert encode_timezone(FixedTimezone(0)) == "UTC" + with pytest.raises(ValueError): + encode_timezone(object())