Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 0 additions & 21 deletions airflow-core/src/airflow/utils/module_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from __future__ import annotations

import pkgutil
import re
from collections.abc import Callable
from importlib import import_module
from typing import TYPE_CHECKING
Expand All @@ -27,26 +26,6 @@
from types import ModuleType


def is_valid_dotpath(path: str) -> bool:
"""
Check if a string follows valid dotpath format (ie: 'package.subpackage.module').

:param path: String to check
"""
if not isinstance(path, str):
return False

# Pattern explanation:
# ^ - Start of string
# [a-zA-Z_] - Must start with letter or underscore
# [a-zA-Z0-9_] - Following chars can be letters, numbers, or underscores
# (\.[a-zA-Z_][a-zA-Z0-9_]*)* - Can be followed by dots and valid identifiers
# $ - End of string
pattern = r"^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)*$"

return bool(re.match(pattern, path))


def import_string(dotted_path: str):
"""
Import a dotted module path and return the attribute/class designated by the last name in the path.
Expand Down
19 changes: 1 addition & 18 deletions airflow-core/tests/unit/utils/test_module_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import pytest

from airflow.utils.module_loading import import_string, is_valid_dotpath
from airflow.utils.module_loading import import_string


class TestModuleImport:
Expand All @@ -33,20 +33,3 @@ def test_import_string(self):
msg = 'Module "airflow.utils" does not define a "nonexistent" attribute'
with pytest.raises(ImportError, match=msg):
import_string("airflow.utils.nonexistent")

@pytest.mark.parametrize(
"path, expected",
[
pytest.param("valid_path", True, id="module_no_dots"),
pytest.param("valid.dot.path", True, id="standard_dotpath"),
pytest.param("package.sub_package.module", True, id="dotpath_with_underscores"),
pytest.param("MyPackage.MyClass", True, id="mixed_case_path"),
pytest.param("invalid..path", False, id="consecutive_dots_fails"),
pytest.param(".invalid.path", False, id="leading_dot_fails"),
pytest.param("invalid.path.", False, id="trailing_dot_fails"),
pytest.param("1invalid.path", False, id="leading_number_fails"),
pytest.param(42, False, id="not_a_string"),
],
)
def test_is_valid_dotpath(self, path, expected):
assert is_valid_dotpath(path) == expected
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/definitions/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def get_uri(self) -> str:
def get_hook(self, *, hook_params=None):
"""Return hook based on conn_type."""
from airflow.providers_manager import ProvidersManager
from airflow.utils.module_loading import import_string
from airflow.sdk.module_loading import import_string

hook = ProvidersManager().hooks.get(self.conn_type, None)

Expand Down
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/definitions/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,7 +1287,7 @@ def _run_task(*, ti, task, run_triggerer=False):
Bypasses a lot of extra steps used in `task.run` to keep our local running as fast as
possible. This function is only meant for the `dag.test` function as a helper function.
"""
from airflow.utils.module_loading import import_string
from airflow.sdk.module_loading import import_string
from airflow.utils.state import State

log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, ti.map_index)
Expand Down
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/definitions/deadline.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
from typing import Any, cast

from airflow.models.deadline import DeadlineReferenceType, ReferenceModels
from airflow.sdk.module_loading import import_string, is_valid_dotpath
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.serde import deserialize, serialize
from airflow.utils.module_loading import import_string, is_valid_dotpath

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/io/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
from fsspec.implementations.local import LocalFileSystem

from airflow.providers_manager import ProvidersManager
from airflow.sdk.module_loading import import_string
from airflow.stats import Stats
from airflow.utils.module_loading import import_string

if TYPE_CHECKING:
from fsspec import AbstractFileSystem
Expand Down
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/io/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def fsid(self) -> str:
return f"{self.fs.protocol}-{self.conn_id or 'env'}"

def serialize(self):
from airflow.utils.module_loading import qualname
from airflow.sdk.module_loading import qualname

return {
"protocol": self.protocol,
Expand Down
82 changes: 82 additions & 0 deletions task-sdk/src/airflow/sdk/module_loading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from collections.abc import Callable
from importlib import import_module


def import_string(dotted_path: str):
"""
Import a dotted module path and return the attribute/class designated by the last name in the path.
Raise ImportError if the import failed.
"""
# TODO: Add support for nested classes. Currently, it only works for top-level classes.
try:
module_path, class_name = dotted_path.rsplit(".", 1)
except ValueError:
raise ImportError(f"{dotted_path} doesn't look like a module path")

module = import_module(module_path)

try:
return getattr(module, class_name)
except AttributeError:
raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute/class')


def qualname(o: object | Callable) -> str:
"""Convert an attribute/class/function to a string importable by ``import_string``."""
if callable(o) and hasattr(o, "__module__") and hasattr(o, "__name__"):
return f"{o.__module__}.{o.__name__}"

cls = o

if not isinstance(cls, type): # instance or class
cls = type(cls)

name = cls.__qualname__
module = cls.__module__

if module and module != "__builtin__":
return f"{module}.{name}"

return name


def is_valid_dotpath(path: str) -> bool:
"""
Check if a string follows valid dotpath format (ie: 'package.subpackage.module').
:param path: String to check
"""
import re

if not isinstance(path, str):
return False

# Pattern explanation:
# ^ - Start of string
# [a-zA-Z_] - Must start with letter or underscore
# [a-zA-Z0-9_] - Following chars can be letters, numbers, or underscores
# (\.[a-zA-Z_][a-zA-Z0-9_]*)* - Can be followed by dots and valid identifiers
# $ - End of string
pattern = r"^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)*$"

return bool(re.match(pattern, path))
2 changes: 1 addition & 1 deletion task-sdk/tests/task_sdk/definitions/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def mock_providers_manager(self):
with mock.patch("airflow.providers_manager.ProvidersManager") as mock_manager:
yield mock_manager

@mock.patch("airflow.utils.module_loading.import_string")
@mock.patch("airflow.sdk.module_loading.import_string")
def test_get_hook(self, mock_import_string, mock_providers_manager):
"""Test that get_hook returns the correct hook instance."""

Expand Down
2 changes: 1 addition & 1 deletion task-sdk/tests/task_sdk/definitions/test_deadline.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
DeadlineReference,
SyncCallback,
)
from airflow.sdk.module_loading import qualname
from airflow.serialization.serde import deserialize, serialize
from airflow.utils.module_loading import qualname

UNIMPORTABLE_DOT_PATH = "valid.but.nonexistent.path"

Expand Down
40 changes: 40 additions & 0 deletions task-sdk/tests/task_sdk/definitions/test_module_loading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import pytest

from airflow.sdk.module_loading import is_valid_dotpath


class TestModuleLoading:
@pytest.mark.parametrize(
"path, expected",
[
pytest.param("valid_path", True, id="module_no_dots"),
pytest.param("valid.dot.path", True, id="standard_dotpath"),
pytest.param("package.sub_package.module", True, id="dotpath_with_underscores"),
pytest.param("MyPackage.MyClass", True, id="mixed_case_path"),
pytest.param("invalid..path", False, id="consecutive_dots_fails"),
pytest.param(".invalid.path", False, id="leading_dot_fails"),
pytest.param("invalid.path.", False, id="trailing_dot_fails"),
pytest.param("1invalid.path", False, id="leading_number_fails"),
pytest.param(42, False, id="not_a_string"),
],
)
def test_is_valid_dotpath(self, path, expected):
assert is_valid_dotpath(path) == expected
1 change: 1 addition & 0 deletions task-sdk/tests/task_sdk/docs/test_public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_airflow_sdk_no_unexpected_exports():
"log",
"exceptions",
"timezone",
"module_loading",
}
unexpected = actual - public - ignore
assert not unexpected, f"Unexpected exports in airflow.sdk: {sorted(unexpected)}"
Expand Down
2 changes: 1 addition & 1 deletion task-sdk/tests/task_sdk/io/test_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from airflow.sdk import Asset, ObjectStoragePath
from airflow.sdk.io import attach
from airflow.sdk.io.store import _STORE_CACHE, ObjectStore
from airflow.utils.module_loading import qualname
from airflow.sdk.module_loading import qualname


def test_init():
Expand Down
Loading