diff --git a/airflow-core/src/airflow/io/__init__.py b/airflow-core/src/airflow/io/__init__.py
index 6bbea93e59ba3..3b255aacdf82a 100644
--- a/airflow-core/src/airflow/io/__init__.py
+++ b/airflow-core/src/airflow/io/__init__.py
@@ -16,102 +16,26 @@
 # under the License.
 from __future__ import annotations
 
-import inspect
-import logging
-from collections.abc import Callable, Mapping
-from functools import cache
-from typing import (
-    TYPE_CHECKING,
+from airflow.utils.deprecation_tools import add_deprecated_classes
+
+add_deprecated_classes(
+    {
+        __name__: {
+            "get_fs": "airflow.sdk.io.get_fs",
+            "has_fs": "airflow.sdk.io.has_fs",
+            "attach": "airflow.sdk.io.attach",
+            "Properties": "airflow.sdk.io.Properties",
+            "_BUILTIN_SCHEME_TO_FS": "airflow.sdk.io.fs._BUILTIN_SCHEME_TO_FS",
+        },
+        "path": {
+            "ObjectStoragePath": "airflow.sdk.ObjectStoragePath",
+        },
+        "storage": {
+            "attach": "airflow.sdk.io.attach",
+        },
+        "typedef": {
+            "Properties": "airflow.sdk.io.typedef.Properties",
+        },
+    },
+    package=__name__,
 )
-
-from fsspec.implementations.local import LocalFileSystem
-
-from airflow.providers_manager import ProvidersManager
-from airflow.stats import Stats
-from airflow.utils.module_loading import import_string
-
-if TYPE_CHECKING:
-    from fsspec import AbstractFileSystem
-
-    from airflow.io.typedef import Properties
-
-
-log = logging.getLogger(__name__)
-
-
-def _file(_: str | None, storage_options: Properties) -> LocalFileSystem:
-    return LocalFileSystem(**storage_options)
-
-
-# builtin supported filesystems
-_BUILTIN_SCHEME_TO_FS: dict[str, Callable[[str | None, Properties], AbstractFileSystem]] = {
-    "file": _file,
-    "local": _file,
-}
-
-
-@cache
-def _register_filesystems() -> Mapping[
-    str,
-    Callable[[str | None, Properties], AbstractFileSystem] | Callable[[str | None], AbstractFileSystem],
-]:
-    scheme_to_fs = _BUILTIN_SCHEME_TO_FS.copy()
-    with Stats.timer("airflow.io.load_filesystems") as timer:
-        manager = ProvidersManager()
-        for fs_module_name in manager.filesystem_module_names:
-            fs_module = import_string(fs_module_name)
-            for scheme in getattr(fs_module, "schemes", []):
-                if scheme in scheme_to_fs:
-                    log.warning("Overriding scheme %s for %s", scheme, fs_module_name)
-
-                method = getattr(fs_module, "get_fs", None)
-                if method is None:
-                    raise ImportError(f"Filesystem {fs_module_name} does not have a get_fs method")
-                scheme_to_fs[scheme] = method
-
-    log.debug("loading filesystems from providers took %.3f seconds", timer.duration)
-    return scheme_to_fs
-
-
-def get_fs(
-    scheme: str, conn_id: str | None = None, storage_options: Properties | None = None
-) -> AbstractFileSystem:
-    """
-    Get a filesystem by scheme.
-
-    :param scheme: the scheme to get the filesystem for
-    :return: the filesystem method
-    :param conn_id: the airflow connection id to use
-    :param storage_options: the storage options to pass to the filesystem
-    """
-    filesystems = _register_filesystems()
-    try:
-        fs = filesystems[scheme]
-    except KeyError:
-        raise ValueError(f"No filesystem registered for scheme {scheme}") from None
-
-    options = storage_options or {}
-
-    # MyPy does not recognize dynamic parameters inspection when we call the method, and we have to do
-    # it for compatibility reasons with already released providers, that's why we need to ignore
-    # mypy errors here
-    parameters = inspect.signature(fs).parameters
-    if len(parameters) == 1:
-        if options:
-            raise AttributeError(
-                f"Filesystem {scheme} does not support storage options, but options were passed."
- f"This most likely means that you are using an old version of the provider that does not " - f"support storage options. Please upgrade the provider if possible." - ) - return fs(conn_id) # type: ignore[call-arg] - return fs(conn_id, options) # type: ignore[call-arg] - - -def has_fs(scheme: str) -> bool: - """ - Check if a filesystem is available for a scheme. - - :param scheme: the scheme to check - :return: True if a filesystem is available for the scheme - """ - return scheme in _register_filesystems() diff --git a/airflow-core/src/airflow/io/path.py b/airflow-core/src/airflow/io/path.py deleted file mode 100644 index bc323d0030bc5..0000000000000 --- a/airflow-core/src/airflow/io/path.py +++ /dev/null @@ -1,22 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from __future__ import annotations - -from airflow.sdk import ObjectStoragePath - -__all__ = ["ObjectStoragePath"] diff --git a/airflow-core/src/airflow/io/storage.py b/airflow-core/src/airflow/io/storage.py deleted file mode 100644 index 4723e8a15f65a..0000000000000 --- a/airflow-core/src/airflow/io/storage.py +++ /dev/null @@ -1,22 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from __future__ import annotations - -from airflow.sdk.io import attach - -__all__ = ["attach"] diff --git a/airflow-core/src/airflow/utils/deprecation_tools.py b/airflow-core/src/airflow/utils/deprecation_tools.py index 3f9977fc96394..e501a758690d5 100644 --- a/airflow-core/src/airflow/utils/deprecation_tools.py +++ b/airflow-core/src/airflow/utils/deprecation_tools.py @@ -84,51 +84,105 @@ def add_deprecated_classes( Add deprecated attribute PEP-563 imports and warnings modules to the package. Works for classes, functions, variables, and other module attributes. - - :param module_imports: imports to use - :param package: package name - :param override_deprecated_classes: override target attributes with deprecated ones. If module + - target attribute is found in the dictionary, it will be displayed in the warning message. 
+    Supports both creating virtual modules and modifying existing modules.
+
+    :param module_imports: imports to use. Format: dict[str, dict[str, str]]
+        - Keys are module names (creates virtual modules)
+        - Special key __name__ modifies the current module for direct attribute imports
+        - Can mix both approaches in a single call
+    :param package: package name (typically __name__)
+    :param override_deprecated_classes: override target attributes with deprecated ones.
+        Format: dict[str, dict[str, str]] matching the structure of module_imports
     :param extra_message: extra message to display in the warning or import error message
 
-    Example:
+    Examples:
+
+        # Create virtual modules (e.g., for removed .py files)
         add_deprecated_classes(
             {"basenotifier": {"BaseNotifier": "airflow.sdk.bases.notifier.BaseNotifier"}},
             package=__name__,
         )
 
-        This makes 'from airflow.notifications.basenotifier import BaseNotifier' still work,
-        even if 'basenotifier.py' was removed, and shows a warning with the new path.
-
-    Wildcard Example:
+        # Wildcard support - redirect all attributes to new module
         add_deprecated_classes(
             {"timezone": {"*": "airflow.sdk.timezone"}},
             package=__name__,
         )
 
-        This makes 'from airflow.utils.timezone import utc' redirect to 'airflow.sdk.timezone.utc',
+        # Current module direct imports
+        add_deprecated_classes(
+            {
+                __name__: {
+                    "get_fs": "airflow.sdk.io.fs.get_fs",
+                    "has_fs": "airflow.sdk.io.fs.has_fs",
+                }
+            },
+            package=__name__,
+        )
+
+        # Mixed behavior - both current module and submodule attributes
+        add_deprecated_classes(
+            {
+                __name__: {
+                    "get_fs": "airflow.sdk.io.fs.get_fs",
+                    "has_fs": "airflow.sdk.io.fs.has_fs",
+                    "Properties": "airflow.sdk.io.typedef.Properties",
+                },
+                "typedef": {
+                    "Properties": "airflow.sdk.io.typedef.Properties",
+                }
+            },
+            package=__name__,
+        )
+
+    The first example makes 'from airflow.notifications.basenotifier import BaseNotifier' work
+    even if 'basenotifier.py' was removed.
+
+    The second example makes 'from airflow.utils.timezone import utc' redirect to 'airflow.sdk.timezone.utc',
        allowing any attribute from the deprecated module to be accessed from the new location.
 
-    Note that "add_deprecated_classes method should be called in the `__init__.py` file in the package
-    where the deprecated classes are located - this way the module `.py` files should be removed and what
-    remains in the package is just the `__init__.py` file.
+    The third example makes 'from airflow.io import get_fs' work with direct imports from the current module.
 
-    See for example `airflow/decorators/__init__.py` file.
+    The fourth example handles both direct imports from the current module and submodule imports.
""" + # Handle both current module and virtual module deprecations for module_name, imports in module_imports.items(): - full_module_name = f"{package}.{module_name}" - module_type = ModuleType(full_module_name) - if override_deprecated_classes and module_name in override_deprecated_classes: - override_deprecated_classes_for_module = override_deprecated_classes[module_name] + if module_name == package: + # Special case: modify the current module for direct attribute imports + if package not in sys.modules: + raise ValueError(f"Module {package} not found in sys.modules") + + module = sys.modules[package] + + # Create the __getattr__ function for current module + current_override = {} + if override_deprecated_classes and package in override_deprecated_classes: + current_override = override_deprecated_classes[package] + + getattr_func = functools.partial( + getattr_with_deprecation, + imports, + package, + current_override, + extra_message or "", + ) + + # Set the __getattr__ function on the current module + setattr(module, "__getattr__", getattr_func) else: - override_deprecated_classes_for_module = {} - - # Mypy is not able to derive the right function signature https://github.com/python/mypy/issues/2427 - module_type.__getattr__ = functools.partial( # type: ignore[assignment] - getattr_with_deprecation, - imports, - full_module_name, - override_deprecated_classes_for_module, - extra_message or "", - ) - sys.modules.setdefault(full_module_name, module_type) + # Create virtual modules for submodule imports + full_module_name = f"{package}.{module_name}" + module_type = ModuleType(full_module_name) + if override_deprecated_classes and module_name in override_deprecated_classes: + override_deprecated_classes_for_module = override_deprecated_classes[module_name] + else: + override_deprecated_classes_for_module = {} + + # Mypy is not able to derive the right function signature https://github.com/python/mypy/issues/2427 + module_type.__getattr__ = functools.partial( # type: ignore[assignment] + getattr_with_deprecation, + imports, + full_module_name, + override_deprecated_classes_for_module, + extra_message or "", + ) + sys.modules.setdefault(full_module_name, module_type) diff --git a/airflow-core/tests/unit/utils/test_deprecation_tools.py b/airflow-core/tests/unit/utils/test_deprecation_tools.py index adaed45ff45b3..fafcde2364c71 100644 --- a/airflow-core/tests/unit/utils/test_deprecation_tools.py +++ b/airflow-core/tests/unit/utils/test_deprecation_tools.py @@ -252,71 +252,239 @@ def test_getattr_with_deprecation_wildcard_allows_non_dunder_attributes(self, no class TestAddDeprecatedClasses: """Tests for the add_deprecated_classes function.""" - def test_add_deprecated_classes_basic(self): - """Test basic functionality of add_deprecated_classes.""" + @pytest.mark.parametrize( + "test_case,module_imports,override_classes,expected_behavior", + [ + ( + "basic_class_mapping", + {"old_module": {"OldClass": "new.module.NewClass"}}, + None, + "creates_virtual_module", + ), + ( + "wildcard_pattern", + {"timezone": {"*": "airflow.sdk.timezone"}}, + None, + "creates_virtual_module", + ), + ( + "with_override", + {"old_module": {"OldClass": "new.module.NewClass"}}, + {"old_module": {"OldClass": "override.module.OverrideClass"}}, + "creates_virtual_module", + ), + ], + ids=["basic_class_mapping", "wildcard_pattern", "with_override"], + ) + def test_virtual_module_creation(self, test_case, module_imports, override_classes, expected_behavior): + """Test add_deprecated_classes creates virtual modules 
correctly.""" # Use unique package and module names to avoid conflicts package_name = get_unique_module_name("test_package") - module_name = f"{package_name}.old_module" - - module_imports = {"old_module": {"OldClass": "new.module.NewClass"}} + module_name = f"{package_name}.{next(iter(module_imports.keys()))}" with temporary_module(module_name): - add_deprecated_classes(module_imports, package_name) + add_deprecated_classes(module_imports, package_name, override_classes) # Check that the module was added to sys.modules assert module_name in sys.modules assert isinstance(sys.modules[module_name], ModuleType) assert hasattr(sys.modules[module_name], "__getattr__") - def test_add_deprecated_classes_with_wildcard(self): - """Test add_deprecated_classes with wildcard pattern.""" - # Use unique package and module names to avoid conflicts - package_name = get_unique_module_name("test_package") - module_name = f"{package_name}.timezone" - - module_imports = {"timezone": {"*": "airflow.sdk.timezone"}} - - with temporary_module(module_name): - add_deprecated_classes(module_imports, package_name) - - # Check that the module was added to sys.modules - assert module_name in sys.modules - assert isinstance(sys.modules[module_name], ModuleType) - assert hasattr(sys.modules[module_name], "__getattr__") - - def test_add_deprecated_classes_with_override(self): - """Test add_deprecated_classes with override_deprecated_classes.""" - # Use unique package and module names to avoid conflicts - package_name = get_unique_module_name("test_package") - module_name = f"{package_name}.old_module" - - module_imports = {"old_module": {"OldClass": "new.module.NewClass"}} - - override_deprecated_classes = {"old_module": {"OldClass": "override.module.OverrideClass"}} - - with temporary_module(module_name): - add_deprecated_classes(module_imports, package_name, override_deprecated_classes) - - # Check that the module was added to sys.modules - assert module_name in sys.modules - assert isinstance(sys.modules[module_name], ModuleType) - def test_add_deprecated_classes_doesnt_override_existing(self): """Test that add_deprecated_classes doesn't override existing modules.""" - # Use unique package and module names to avoid conflicts - package_name = get_unique_module_name("test_package") - module_name = f"{package_name}.existing_module" + module_name = get_unique_module_name("existing_module") + full_module_name = f"airflow.test.{module_name}" + + # Create an existing module + existing_module = ModuleType(full_module_name) + existing_module.existing_attr = "existing_value" + sys.modules[full_module_name] = existing_module + + with temporary_module(full_module_name): + # This should not override the existing module + add_deprecated_classes( + {module_name: {"NewClass": "new.module.NewClass"}}, + package="airflow.test", + ) - module_imports = {"existing_module": {"SomeClass": "new.module.SomeClass"}} + # The existing module should still be there + assert sys.modules[full_module_name] == existing_module + assert sys.modules[full_module_name].existing_attr == "existing_value" + + @pytest.mark.parametrize( + "test_case,module_imports,attr_name,target_attr,expected_target_msg,override_classes", + [ + ( + "direct_imports", + { + "get_something": "target.module.get_something", + "another_attr": "target.module.another_attr", + }, + "get_something", + "get_something", + "target.module.get_something", + None, + ), + ( + "with_wildcard", + {"specific_attr": "target.module.specific_attr", "*": "target.module"}, + "any_attribute", + 
"any_attribute", + "target.module.any_attribute", + None, + ), + ( + "with_override", + {"get_something": "target.module.get_something"}, + "get_something", + "get_something", + "override.module.OverrideClass", + {"get_something": "override.module.OverrideClass"}, + ), + ], + ids=["direct_imports", "with_wildcard", "with_override"], + ) + def test_current_module_deprecation( + self, test_case, module_imports, attr_name, target_attr, expected_target_msg, override_classes + ): + """Test add_deprecated_classes with current module (__name__ key) functionality.""" + module_name = get_unique_module_name(f"{test_case}_module") + full_module_name = f"airflow.test.{module_name}" + + # Create a module to modify + test_module = ModuleType(full_module_name) + sys.modules[full_module_name] = test_module + + with temporary_module(full_module_name): + # Mock the target module and attribute + mock_target_module = mock.MagicMock() + mock_attribute = mock.MagicMock() + setattr(mock_target_module, target_attr, mock_attribute) + + with mock.patch( + "airflow.utils.deprecation_tools.importlib.import_module", return_value=mock_target_module + ): + # Prepare override parameter + override_param = {full_module_name: override_classes} if override_classes else None + + add_deprecated_classes( + {full_module_name: module_imports}, + package=full_module_name, + override_deprecated_classes=override_param, + ) - with temporary_module(module_name): - # Create a mock existing module - existing_module = ModuleType(module_name) - existing_module.existing_attribute = "existing_value" - sys.modules[module_name] = existing_module + # The module should now have a __getattr__ method + assert hasattr(test_module, "__getattr__") + + # Test that accessing the deprecated attribute works + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = getattr(test_module, attr_name) + + assert result == mock_attribute + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert f"{full_module_name}.{attr_name}" in str(w[0].message) + assert expected_target_msg in str(w[0].message) + + def test_add_deprecated_classes_mixed_current_and_virtual_modules(self): + """Test add_deprecated_classes with mixed current module and virtual module imports.""" + base_module_name = get_unique_module_name("mixed_module") + full_module_name = f"airflow.test.{base_module_name}" + virtual_module_name = f"{base_module_name}_virtual" + full_virtual_module_name = f"{full_module_name}.{virtual_module_name}" + + # Create a module to modify + test_module = ModuleType(full_module_name) + sys.modules[full_module_name] = test_module + + with temporary_module(full_module_name), temporary_module(full_virtual_module_name): + # Mock the target modules and attributes + mock_current_module = mock.MagicMock() + mock_current_attr = mock.MagicMock() + mock_current_module.current_attr = mock_current_attr + + mock_virtual_module = mock.MagicMock() + mock_virtual_attr = mock.MagicMock() + mock_virtual_module.VirtualClass = mock_virtual_attr + + def mock_import_module(module_name): + if "current.module" in module_name: + return mock_current_module + if "virtual.module" in module_name: + return mock_virtual_module + raise ImportError(f"Module {module_name} not found") + + with mock.patch( + "airflow.utils.deprecation_tools.importlib.import_module", side_effect=mock_import_module + ): + add_deprecated_classes( + { + full_module_name: { + "current_attr": "current.module.current_attr", + }, + virtual_module_name: { + 
"VirtualClass": "virtual.module.VirtualClass", + }, + }, + package=full_module_name, + ) + + # Test current module access + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = test_module.current_attr + + assert result == mock_current_attr + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert f"{full_module_name}.current_attr" in str(w[0].message) + + # Test virtual module access + virtual_module = sys.modules[full_virtual_module_name] + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = virtual_module.VirtualClass + + assert result == mock_virtual_attr + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert f"{full_virtual_module_name}.VirtualClass" in str(w[0].message) + + def test_add_deprecated_classes_current_module_not_in_sys_modules(self): + """Test add_deprecated_classes raises error when current module not in sys.modules.""" + nonexistent_module = "nonexistent.module.name" + + with pytest.raises(ValueError, match=f"Module {nonexistent_module} not found in sys.modules"): + add_deprecated_classes( + {nonexistent_module: {"attr": "target.module.attr"}}, + package=nonexistent_module, + ) + + def test_add_deprecated_classes_preserves_existing_module_attributes(self): + """Test that add_deprecated_classes preserves existing module attributes.""" + module_name = get_unique_module_name("preserve_module") + full_module_name = f"airflow.test.{module_name}" + + # Create a module with existing attributes + test_module = ModuleType(full_module_name) + test_module.existing_attr = "existing_value" + test_module.existing_function = lambda: "existing_function_result" + sys.modules[full_module_name] = test_module + + with temporary_module(full_module_name): + add_deprecated_classes( + { + full_module_name: { + "deprecated_attr": "target.module.deprecated_attr", + } + }, + package=full_module_name, + ) - add_deprecated_classes(module_imports, package_name) + # Existing attributes should still be accessible + assert test_module.existing_attr == "existing_value" + assert test_module.existing_function() == "existing_function_result" - # Check that the existing module was not overridden - assert sys.modules[module_name] is existing_module - assert sys.modules[module_name].existing_attribute == "existing_value" + # The module should have __getattr__ for deprecated attributes + assert hasattr(test_module, "__getattr__") diff --git a/providers/common/io/tests/unit/common/io/xcom/test_backend.py b/providers/common/io/tests/unit/common/io/xcom/test_backend.py index e1b242daf4281..72f868c62fcd3 100644 --- a/providers/common/io/tests/unit/common/io/xcom/test_backend.py +++ b/providers/common/io/tests/unit/common/io/xcom/test_backend.py @@ -38,7 +38,7 @@ from airflow.sdk.execution_time.comms import XComResult from airflow.sdk.execution_time.xcom import resolve_xcom_backend else: - from airflow.io.path import ObjectStoragePath + from airflow.io.path import ObjectStoragePath # type: ignore[no-redef] from airflow.models.xcom import BaseXCom, resolve_xcom_backend # type: ignore[no-redef] diff --git a/task-sdk/src/airflow/sdk/io/__init__.py b/task-sdk/src/airflow/sdk/io/__init__.py index 4247c5e4acf23..0fec7a7bc037c 100644 --- a/task-sdk/src/airflow/sdk/io/__init__.py +++ b/task-sdk/src/airflow/sdk/io/__init__.py @@ -17,7 +17,9 @@ from __future__ import annotations +from airflow.sdk.io.fs import get_fs, has_fs from airflow.sdk.io.path import ObjectStoragePath from 
airflow.sdk.io.store import attach +from airflow.sdk.io.typedef import Properties -__all__ = ["ObjectStoragePath", "attach"] +__all__ = ["ObjectStoragePath", "attach", "get_fs", "has_fs", "Properties"] diff --git a/task-sdk/src/airflow/sdk/io/fs.py b/task-sdk/src/airflow/sdk/io/fs.py new file mode 100644 index 0000000000000..45f07ec46eb39 --- /dev/null +++ b/task-sdk/src/airflow/sdk/io/fs.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import inspect +import logging +from collections.abc import Callable, Mapping +from functools import cache +from typing import TYPE_CHECKING + +from fsspec.implementations.local import LocalFileSystem + +from airflow.providers_manager import ProvidersManager +from airflow.stats import Stats +from airflow.utils.module_loading import import_string + +if TYPE_CHECKING: + from fsspec import AbstractFileSystem + + from airflow.sdk.io.typedef import Properties + + +log = logging.getLogger(__name__) + + +def _file(_: str | None, storage_options: Properties) -> LocalFileSystem: + return LocalFileSystem(**storage_options) + + +# builtin supported filesystems +_BUILTIN_SCHEME_TO_FS: dict[str, Callable[[str | None, Properties], AbstractFileSystem]] = { + "file": _file, + "local": _file, +} + + +@cache +def _register_filesystems() -> Mapping[ + str, + Callable[[str | None, Properties], AbstractFileSystem] | Callable[[str | None], AbstractFileSystem], +]: + scheme_to_fs = _BUILTIN_SCHEME_TO_FS.copy() + with Stats.timer("airflow.io.load_filesystems") as timer: + manager = ProvidersManager() + for fs_module_name in manager.filesystem_module_names: + fs_module = import_string(fs_module_name) + for scheme in getattr(fs_module, "schemes", []): + if scheme in scheme_to_fs: + log.warning("Overriding scheme %s for %s", scheme, fs_module_name) + + method = getattr(fs_module, "get_fs", None) + if method is None: + raise ImportError(f"Filesystem {fs_module_name} does not have a get_fs method") + scheme_to_fs[scheme] = method + + log.debug("loading filesystems from providers took %.3f seconds", timer.duration) + return scheme_to_fs + + +def get_fs( + scheme: str, conn_id: str | None = None, storage_options: Properties | None = None +) -> AbstractFileSystem: + """ + Get a filesystem by scheme. 
+
+    :param scheme: the scheme to get the filesystem for
+    :return: the filesystem method
+    :param conn_id: the airflow connection id to use
+    :param storage_options: the storage options to pass to the filesystem
+    """
+    filesystems = _register_filesystems()
+    try:
+        fs = filesystems[scheme]
+    except KeyError:
+        raise ValueError(f"No filesystem registered for scheme {scheme}") from None
+
+    options = storage_options or {}
+
+    # MyPy does not recognize dynamic parameters inspection when we call the method, and we have to do
+    # it for compatibility reasons with already released providers, that's why we need to ignore
+    # mypy errors here
+    parameters = inspect.signature(fs).parameters
+    if len(parameters) == 1:
+        if options:
+            raise AttributeError(
+                f"Filesystem {scheme} does not support storage options, but options were passed."
+                f"This most likely means that you are using an old version of the provider that does not "
+                f"support storage options. Please upgrade the provider if possible."
+            )
+        return fs(conn_id)  # type: ignore[call-arg]
+    return fs(conn_id, options)  # type: ignore[call-arg]
+
+
+def has_fs(scheme: str) -> bool:
+    """
+    Check if a filesystem is available for a scheme.
+
+    :param scheme: the scheme to check
+    :return: True if a filesystem is available for the scheme
+    """
+    return scheme in _register_filesystems()
diff --git a/task-sdk/src/airflow/sdk/io/store.py b/task-sdk/src/airflow/sdk/io/store.py
index 68c7ad9fbf868..05f9b98b136d3 100644
--- a/task-sdk/src/airflow/sdk/io/store.py
+++ b/task-sdk/src/airflow/sdk/io/store.py
@@ -22,7 +22,7 @@
 if TYPE_CHECKING:
     from fsspec import AbstractFileSystem
 
-    from airflow.io.typedef import Properties
+    from airflow.sdk.io.typedef import Properties
 
 
 class ObjectStore:
@@ -57,7 +57,7 @@ def __str__(self):
 
     @cached_property
     def fs(self) -> AbstractFileSystem:
-        from airflow.io import get_fs
+        from airflow.sdk.io import get_fs
 
         # if the fs is provided in init, the next statement will be ignored
         return get_fs(self.protocol, self.conn_id)
@@ -89,7 +89,7 @@ def serialize(self):
 
     @classmethod
    def deserialize(cls, data: dict[str, str], version: int):
-        from airflow.io import has_fs
+        from airflow.sdk.io import has_fs
 
         if version > cls.__version__:
             raise ValueError(f"Cannot deserialize version {version} for {cls.__name__}")
diff --git a/airflow-core/src/airflow/io/typedef.py b/task-sdk/src/airflow/sdk/io/typedef.py
similarity index 100%
rename from airflow-core/src/airflow/io/typedef.py
rename to task-sdk/src/airflow/sdk/io/typedef.py
diff --git a/task-sdk/tests/task_sdk/io/test_path.py b/task-sdk/tests/task_sdk/io/test_path.py
index 1e85b2ee7a288..696d14004319f 100644
--- a/task-sdk/tests/task_sdk/io/test_path.py
+++ b/task-sdk/tests/task_sdk/io/test_path.py
@@ -325,7 +325,7 @@ def test_serde_store(self):
 class TestBackwardsCompatibility:
     @pytest.fixture(autouse=True)
     def reset(self):
-        from airflow.io import _register_filesystems
+        from airflow.sdk.io.fs import _register_filesystems
 
         _register_filesystems.cache_clear()
         yield
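
Illustration only, not part of the patch: a minimal sketch of how the deprecated import path is expected to behave after this change, assuming an environment where airflow-core, the Task SDK, and fsspec are all installed. The module-level __getattr__ installed by add_deprecated_classes resolves the old name, emits a DeprecationWarning pointing at the new airflow.sdk.io location, and returns the SDK object.

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Old import path still works; it is served by the __getattr__ shim on airflow.io
    from airflow.io import get_fs

fs = get_fs("file")  # "file" resolves via the builtin LocalFileSystem mapping
print(type(fs).__name__)            # LocalFileSystem
print(caught[0].category.__name__)  # DeprecationWarning
print(str(caught[0].message))       # names airflow.sdk.io.get_fs as the new location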