diff --git a/airflow-core/src/airflow/utils/module_loading.py b/airflow-core/src/airflow/utils/module_loading.py index 028b5d31103ec..e0ec74bcb1f03 100644 --- a/airflow-core/src/airflow/utils/module_loading.py +++ b/airflow-core/src/airflow/utils/module_loading.py @@ -18,7 +18,6 @@ from __future__ import annotations import pkgutil -import re from collections.abc import Callable from importlib import import_module from typing import TYPE_CHECKING @@ -27,26 +26,6 @@ from types import ModuleType -def is_valid_dotpath(path: str) -> bool: - """ - Check if a string follows valid dotpath format (ie: 'package.subpackage.module'). - - :param path: String to check - """ - if not isinstance(path, str): - return False - - # Pattern explanation: - # ^ - Start of string - # [a-zA-Z_] - Must start with letter or underscore - # [a-zA-Z0-9_] - Following chars can be letters, numbers, or underscores - # (\.[a-zA-Z_][a-zA-Z0-9_]*)* - Can be followed by dots and valid identifiers - # $ - End of string - pattern = r"^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)*$" - - return bool(re.match(pattern, path)) - - def import_string(dotted_path: str): """ Import a dotted module path and return the attribute/class designated by the last name in the path. 
diff --git a/airflow-core/tests/unit/utils/test_module_loading.py b/airflow-core/tests/unit/utils/test_module_loading.py index 2b0659bc841c5..1f92a004b8fa8 100644 --- a/airflow-core/tests/unit/utils/test_module_loading.py +++ b/airflow-core/tests/unit/utils/test_module_loading.py @@ -19,7 +19,7 @@ import pytest -from airflow.utils.module_loading import import_string, is_valid_dotpath +from airflow.utils.module_loading import import_string class TestModuleImport: @@ -33,20 +33,3 @@ def test_import_string(self): msg = 'Module "airflow.utils" does not define a "nonexistent" attribute' with pytest.raises(ImportError, match=msg): import_string("airflow.utils.nonexistent") - - @pytest.mark.parametrize( - "path, expected", - [ - pytest.param("valid_path", True, id="module_no_dots"), - pytest.param("valid.dot.path", True, id="standard_dotpath"), - pytest.param("package.sub_package.module", True, id="dotpath_with_underscores"), - pytest.param("MyPackage.MyClass", True, id="mixed_case_path"), - pytest.param("invalid..path", False, id="consecutive_dots_fails"), - pytest.param(".invalid.path", False, id="leading_dot_fails"), - pytest.param("invalid.path.", False, id="trailing_dot_fails"), - pytest.param("1invalid.path", False, id="leading_number_fails"), - pytest.param(42, False, id="not_a_string"), - ], - ) - def test_is_valid_dotpath(self, path, expected): - assert is_valid_dotpath(path) == expected diff --git a/task-sdk/src/airflow/sdk/definitions/connection.py b/task-sdk/src/airflow/sdk/definitions/connection.py index cb7cce6445666..612582321bda9 100644 --- a/task-sdk/src/airflow/sdk/definitions/connection.py +++ b/task-sdk/src/airflow/sdk/definitions/connection.py @@ -126,7 +126,7 @@ def get_uri(self) -> str: def get_hook(self, *, hook_params=None): """Return hook based on conn_type.""" from airflow.providers_manager import ProvidersManager - from airflow.utils.module_loading import import_string + from airflow.sdk.module_loading import import_string hook = 
ProvidersManager().hooks.get(self.conn_type, None) diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 4b56c9bd94e53..4f6955dfad237 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -1287,7 +1287,7 @@ def _run_task(*, ti, task, run_triggerer=False): Bypasses a lot of extra steps used in `task.run` to keep our local running as fast as possible. This function is only meant for the `dag.test` function as a helper function. """ - from airflow.utils.module_loading import import_string + from airflow.sdk.module_loading import import_string from airflow.utils.state import State log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, ti.map_index) diff --git a/task-sdk/src/airflow/sdk/definitions/deadline.py b/task-sdk/src/airflow/sdk/definitions/deadline.py index 8e1e67ae08118..966e2b926a61f 100644 --- a/task-sdk/src/airflow/sdk/definitions/deadline.py +++ b/task-sdk/src/airflow/sdk/definitions/deadline.py @@ -24,9 +24,9 @@ from typing import Any, cast from airflow.models.deadline import DeadlineReferenceType, ReferenceModels +from airflow.sdk.module_loading import import_string, is_valid_dotpath from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.serde import deserialize, serialize -from airflow.utils.module_loading import import_string, is_valid_dotpath logger = logging.getLogger(__name__) diff --git a/task-sdk/src/airflow/sdk/io/fs.py b/task-sdk/src/airflow/sdk/io/fs.py index 45f07ec46eb39..a49d7c6ffcc53 100644 --- a/task-sdk/src/airflow/sdk/io/fs.py +++ b/task-sdk/src/airflow/sdk/io/fs.py @@ -25,8 +25,8 @@ from fsspec.implementations.local import LocalFileSystem from airflow.providers_manager import ProvidersManager +from airflow.sdk.module_loading import import_string from airflow.stats import Stats -from airflow.utils.module_loading import import_string if TYPE_CHECKING: from fsspec import 
AbstractFileSystem diff --git a/task-sdk/src/airflow/sdk/io/store.py b/task-sdk/src/airflow/sdk/io/store.py index 05f9b98b136d3..c38fb81bb8227 100644 --- a/task-sdk/src/airflow/sdk/io/store.py +++ b/task-sdk/src/airflow/sdk/io/store.py @@ -78,7 +78,7 @@ def fsid(self) -> str: return f"{self.fs.protocol}-{self.conn_id or 'env'}" def serialize(self): - from airflow.utils.module_loading import qualname + from airflow.sdk.module_loading import qualname return { "protocol": self.protocol, diff --git a/task-sdk/src/airflow/sdk/module_loading.py b/task-sdk/src/airflow/sdk/module_loading.py new file mode 100644 index 0000000000000..6b9d572a7ba8d --- /dev/null +++ b/task-sdk/src/airflow/sdk/module_loading.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from collections.abc import Callable +from importlib import import_module + + +def import_string(dotted_path: str): + """ + Import a dotted module path and return the attribute/class designated by the last name in the path. + + Raise ImportError if the import failed. + """ + # TODO: Add support for nested classes. Currently, it only works for top-level classes. 
+ try: + module_path, class_name = dotted_path.rsplit(".", 1) + except ValueError: + raise ImportError(f"{dotted_path} doesn't look like a module path") + + module = import_module(module_path) + + try: + return getattr(module, class_name) + except AttributeError: + raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute/class') + + +def qualname(o: object | Callable) -> str: + """Convert an attribute/class/function to a string importable by ``import_string``.""" + if callable(o) and hasattr(o, "__module__") and hasattr(o, "__name__"): + return f"{o.__module__}.{o.__name__}" + + cls = o + + if not isinstance(cls, type): # instance or class + cls = type(cls) + + name = cls.__qualname__ + module = cls.__module__ + + if module and module != "__builtin__": + return f"{module}.{name}" + + return name + + +def is_valid_dotpath(path: str) -> bool: + """ + Check if a string follows valid dotpath format (ie: 'package.subpackage.module'). + + :param path: String to check + """ + import re + + if not isinstance(path, str): + return False + + # Pattern explanation: + # ^ - Start of string + # [a-zA-Z_] - Must start with letter or underscore + # [a-zA-Z0-9_] - Following chars can be letters, numbers, or underscores + # (\.[a-zA-Z_][a-zA-Z0-9_]*)* - Can be followed by dots and valid identifiers + # $ - End of string + pattern = r"^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)*$" + + return bool(re.match(pattern, path)) diff --git a/task-sdk/tests/task_sdk/definitions/test_connections.py b/task-sdk/tests/task_sdk/definitions/test_connections.py index 6e4d977c6591b..508cb35891d87 100644 --- a/task-sdk/tests/task_sdk/definitions/test_connections.py +++ b/task-sdk/tests/task_sdk/definitions/test_connections.py @@ -39,7 +39,7 @@ def mock_providers_manager(self): with mock.patch("airflow.providers_manager.ProvidersManager") as mock_manager: yield mock_manager - @mock.patch("airflow.utils.module_loading.import_string") + 
@mock.patch("airflow.sdk.module_loading.import_string") def test_get_hook(self, mock_import_string, mock_providers_manager): """Test that get_hook returns the correct hook instance.""" diff --git a/task-sdk/tests/task_sdk/definitions/test_deadline.py b/task-sdk/tests/task_sdk/definitions/test_deadline.py index 761a86ae0a9a2..8bb70a7fad2a8 100644 --- a/task-sdk/tests/task_sdk/definitions/test_deadline.py +++ b/task-sdk/tests/task_sdk/definitions/test_deadline.py @@ -28,8 +28,8 @@ DeadlineReference, SyncCallback, ) +from airflow.sdk.module_loading import qualname from airflow.serialization.serde import deserialize, serialize -from airflow.utils.module_loading import qualname UNIMPORTABLE_DOT_PATH = "valid.but.nonexistent.path" diff --git a/task-sdk/tests/task_sdk/definitions/test_module_loading.py b/task-sdk/tests/task_sdk/definitions/test_module_loading.py new file mode 100644 index 0000000000000..9389e2c8e35f4 --- /dev/null +++ b/task-sdk/tests/task_sdk/definitions/test_module_loading.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
import pytest

from airflow.sdk.module_loading import is_valid_dotpath

# One entry per scenario: (pytest id, candidate value, expected verdict).
_DOTPATH_SCENARIOS = (
    ("module_no_dots", "valid_path", True),
    ("standard_dotpath", "valid.dot.path", True),
    ("dotpath_with_underscores", "package.sub_package.module", True),
    ("mixed_case_path", "MyPackage.MyClass", True),
    ("consecutive_dots_fails", "invalid..path", False),
    ("leading_dot_fails", ".invalid.path", False),
    ("trailing_dot_fails", "invalid.path.", False),
    ("leading_number_fails", "1invalid.path", False),
    ("not_a_string", 42, False),
)


class TestModuleLoading:
    """Checks for the helpers exposed by ``airflow.sdk.module_loading``."""

    @pytest.mark.parametrize(
        "path, expected",
        [
            pytest.param(candidate, verdict, id=scenario_id)
            for scenario_id, candidate, verdict in _DOTPATH_SCENARIOS
        ],
    )
    def test_is_valid_dotpath(self, path, expected):
        """``is_valid_dotpath`` must return the expected verdict for each scenario."""
        verdict = is_valid_dotpath(path)
        assert verdict == expected