Skip to content

Commit 5babf15

Browse files
authored
Resolve serialization for numpy bool 1.x and 2.x compatibility (#53690)
* fix serialization for np.bool in numpy 2.x get update from main * update serializers importable test case to handle numpy bool compatibility between v1 and v2 * use pytest xfail to capture np.bool import error in numpy version less than 2 * parameterize the serializers importable test, so xfail can be applied individually to numpy bool * add textwrap dedent for message * add np.float32 to serializer and improve test description
1 parent eae6578 commit 5babf15

File tree

3 files changed

+77
-32
lines changed

3 files changed

+77
-32
lines changed

airflow-core/src/airflow/serialization/serializers/numpy.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from typing import TYPE_CHECKING, Any
2121

22-
from airflow.utils.module_loading import import_string, qualname
22+
from airflow.utils.module_loading import qualname
2323

2424
# lazy loading for performance reasons
2525
serializers = [
@@ -31,11 +31,13 @@
3131
"numpy.uint16",
3232
"numpy.uint32",
3333
"numpy.uint64",
34-
"numpy.bool_",
3534
"numpy.float64",
35+
"numpy.float32",
3636
"numpy.float16",
3737
"numpy.complex128",
3838
"numpy.complex64",
39+
"numpy.bool",
40+
"numpy.bool_",
3941
]
4042

4143
if TYPE_CHECKING:
@@ -70,7 +72,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
7072
):
7173
return int(o), *metadata
7274

73-
if isinstance(o, np.bool_):
75+
if hasattr(np, "bool") and isinstance(o, np.bool) or isinstance(o, np.bool_):
7476
return bool(o), *metadata
7577

7678
if isinstance(o, (np.float16, np.float32, np.float64, np.complex64, np.complex128)):
@@ -83,9 +85,4 @@ def deserialize(cls: type, version: int, data: str) -> Any:
8385
if version > __version__:
8486
raise TypeError("serialized version is newer than class version")
8587

86-
allowed_deserialize_classes = [import_string(classname) for classname in deserializers]
87-
88-
if cls not in allowed_deserialize_classes:
89-
raise TypeError(f"unsupported {qualname(cls)} found for numpy deserialization")
90-
9188
return cls(data)

airflow-core/tests/unit/serialization/serializers/test_serializers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,6 @@ def test_numpy_serializers(self):
261261
("klass", "ver", "value", "msg"),
262262
[
263263
(np.int32, 999, 123, r"serialized version is newer"),
264-
(np.float32, 1, 123, r"unsupported numpy\.float32"),
265264
],
266265
)
267266
def test_numpy_deserialize_errors(self, klass, ver, value, msg):

airflow-core/tests/unit/serialization/test_serde.py

Lines changed: 72 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818

1919
import datetime
2020
import enum
21+
import textwrap
2122
from collections import namedtuple
2223
from dataclasses import dataclass
23-
from importlib import import_module
24+
from importlib import import_module, metadata
2425
from typing import ClassVar
2526

2627
import attr
2728
import pytest
29+
from packaging import version
2830
from pydantic import BaseModel
2931

3032
from airflow.sdk.definitions.asset import Asset
@@ -61,6 +63,67 @@ def recalculate_patterns():
6163
_match_regexp.cache_clear()
6264

6365

66+
def generate_serializers_importable_tests():
67+
"""
68+
Generate test cases for `test_serializers_importable_and_str`.
69+
70+
The function iterates through all the modules defined under `airflow.serialization.serializers`. It loads
71+
the import strings defined in the `serializers` from each module, and create a test case to verify that the
72+
serializer is importable.
73+
"""
74+
import airflow.serialization.serializers
75+
76+
NUMPY_VERSION = version.parse(metadata.version("numpy"))
77+
78+
serializer_tests = []
79+
80+
for _, name, _ in iter_namespace(airflow.serialization.serializers):
81+
############################################################
82+
# Handle compatibility / optional dependency at module level
83+
############################################################
84+
# https://github.com/apache/airflow/pull/37320
85+
if name == "airflow.serialization.serializers.iceberg":
86+
try:
87+
import pyiceberg # noqa: F401
88+
except ImportError:
89+
continue
90+
# https://github.com/apache/airflow/pull/38074
91+
if name == "airflow.serialization.serializers.deltalake":
92+
try:
93+
import deltalake # noqa: F401
94+
except ImportError:
95+
continue
96+
mod = import_module(name)
97+
for s in getattr(mod, "serializers", list()):
98+
############################################################
99+
# Handle compatibility issue at serializer level
100+
############################################################
101+
if s == "numpy.bool" and NUMPY_VERSION.major < 2:
102+
reason = textwrap.dedent(f"""\
103+
Current NumPy version: {NUMPY_VERSION}
104+
105+
In NumPy 1.20, `numpy.bool` was deprecated as an alias for the built-in `bool`.
106+
For NumPy versions <= 1.26, attempting to import `numpy.bool` raises an ImportError.
107+
Starting with NumPy 2.0, `numpy.bool` is reintroduced as the NumPy scalar type,
108+
and `numpy.bool_` becomes an alias for `numpy.bool`.
109+
110+
The serializers are loaded lazily at runtime. As a result:
111+
- With NumPy <= 1.26, only `numpy.bool_` is loaded.
112+
- With NumPy >= 2.0, only `numpy.bool` is loaded.
113+
114+
This test case deliberately attempts to import both `numpy.bool` and `numpy.bool_`,
115+
regardless of the installed NumPy version. Therefore, when NumPy <= 1.26 is installed,
116+
importing `numpy.bool` will raise an ImportError.
117+
""")
118+
serializer_tests.append(pytest.param(name, s, marks=pytest.mark.skip(reason=reason)))
119+
else:
120+
serializer_tests.append(pytest.param(name, s))
121+
return serializer_tests
122+
123+
124+
SERIALIZER_TESTS = generate_serializers_importable_tests()
125+
126+
64127
class Z:
65128
__version__: ClassVar[int] = 1
66129

@@ -386,29 +449,15 @@ def test_encode_asset(self):
386449
obj = deserialize(serialize(asset))
387450
assert asset.uri == obj.uri
388451

389-
def test_serializers_importable_and_str(self):
452+
@pytest.mark.parametrize("name, s", SERIALIZER_TESTS)
453+
def test_serializers_importable_and_str(self, name, s):
390454
"""Test if all distributed serializers are lazy loading and can be imported"""
391-
import airflow.serialization.serializers
392-
393-
for _, name, _ in iter_namespace(airflow.serialization.serializers):
394-
if name == "airflow.serialization.serializers.iceberg":
395-
try:
396-
import pyiceberg # noqa: F401
397-
except ImportError:
398-
continue
399-
if name == "airflow.serialization.serializers.deltalake":
400-
try:
401-
import deltalake # noqa: F401
402-
except ImportError:
403-
continue
404-
mod = import_module(name)
405-
for s in getattr(mod, "serializers", list()):
406-
if not isinstance(s, str):
407-
raise TypeError(f"{s} is not of type str. This is required for lazy loading")
408-
try:
409-
import_string(s)
410-
except ImportError:
411-
raise AttributeError(f"{s} cannot be imported (located in {name})")
455+
if not isinstance(s, str):
456+
raise TypeError(f"{s} is not of type str. This is required for lazy loading")
457+
try:
458+
import_string(s)
459+
except ImportError:
460+
raise AttributeError(f"{s} cannot be imported (located in {name})")
412461

413462
def test_stringify(self):
414463
i = V(W(10), ["l1", "l2"], (1, 2), 10)

0 commit comments

Comments
 (0)