diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 46f22f47cfda6..dacc21da6f2a8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -596,6 +596,7 @@ repos: ^airflow-core/src/airflow/utils/db\.py$| ^airflow-core/src/airflow/utils/trigger_rule\.py$| ^airflow-core/tests/| + ^task-sdk/tests/| ^.*changelog\.(rst|txt)$| ^.*CHANGELOG\.(rst|txt)$| ^chart/values.schema\.json$| diff --git a/airflow-core/pyproject.toml b/airflow-core/pyproject.toml index 73ee1cd450831..0eced80cc15b8 100644 --- a/airflow-core/pyproject.toml +++ b/airflow-core/pyproject.toml @@ -240,6 +240,7 @@ exclude = [ "../shared/timezones/src/airflow_shared/timezones" = "src/airflow/_shared/timezones" "../shared/listeners/src/airflow_shared/listeners" = "src/airflow/_shared/listeners" "../shared/plugins_manager/src/airflow_shared/plugins_manager" = "src/airflow/_shared/plugins_manager" +"../shared/providers_discovery/src/airflow_shared/providers_discovery" = "src/airflow/_shared/providers_discovery" [tool.hatch.build.targets.custom] path = "./hatch_build.py" @@ -317,4 +318,5 @@ shared_distributions = [ "apache-airflow-shared-secrets-masker", "apache-airflow-shared-timezones", "apache-airflow-shared-plugins-manager", + "apache-airflow-shared-providers-discovery", ] diff --git a/airflow-core/src/airflow/_shared/providers_discovery b/airflow-core/src/airflow/_shared/providers_discovery new file mode 120000 index 0000000000000..818cea30f3372 --- /dev/null +++ b/airflow-core/src/airflow/_shared/providers_discovery @@ -0,0 +1 @@ +../../../../shared/providers_discovery/src/airflow_shared/providers_discovery \ No newline at end of file diff --git a/airflow-core/src/airflow/providers_manager.py b/airflow-core/src/airflow/providers_manager.py index 5dc70d0fc9e56..956f52857bbde 100644 --- a/airflow-core/src/airflow/providers_manager.py +++ b/airflow-core/src/airflow/providers_manager.py @@ -33,9 +33,9 @@ from time import perf_counter from typing import TYPE_CHECKING, Any, NamedTuple, ParamSpec, TypeVar, cast -from packaging.utils import canonicalize_name - -from airflow._shared.module_loading import entry_points_with_dist, import_string +from airflow import DeprecatedImportWarning +from airflow._shared.module_loading import import_string +from airflow._shared.providers_discovery import discover_all_providers_from_packages from airflow.exceptions import AirflowOptionalProviderFeatureException from airflow.utils.log.logging_mixin import LoggingMixin @@ -438,6 +438,33 @@ def __init__(self): self._plugins_set: set[PluginInfo] = set() self._init_airflow_core_hooks() + self._runtime_manager = None + + def __getattribute__(self, name: str): + # Hacky but does the trick for now + runtime_properties = { + "hooks", + "taskflow_decorators", + "filesystem_module_names", + "asset_factories", + "asset_uri_handlers", + "asset_to_openlineage_converters", + } + + if name in runtime_properties: + warnings.warn( + f"ProvidersManager.{name} is deprecated. 
Use ProvidersManagerTaskRuntime.{name} from task-sdk instead.", + DeprecatedImportWarning, + stacklevel=2, + ) + if object.__getattribute__(self, "_runtime_manager") is None: + from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime + + object.__setattr__(self, "_runtime_manager", ProvidersManagerTaskRuntime()) + return getattr(object.__getattribute__(self, "_runtime_manager"), name) + + return object.__getattribute__(self, name) + def _init_airflow_core_hooks(self): """Initialize the hooks dict with default hooks from Airflow core.""" core_dummy_hooks = { @@ -472,7 +499,7 @@ def initialize_providers_list(self): # Development purpose. In production provider.yaml files are not present in the 'airflow" directory # So there is no risk we are going to override package provider accidentally. This can only happen # in case of local development - self._discover_all_providers_from_packages() + discover_all_providers_from_packages(self._provider_dict, self._provider_schema_validator) self._verify_all_providers_all_compatible() self._provider_dict = dict(sorted(self._provider_dict.items())) @@ -607,57 +634,6 @@ def initialize_providers_cli_command(self): self.initialize_providers_list() self._discover_cli_command() - def _discover_all_providers_from_packages(self) -> None: - """ - Discover all providers by scanning packages installed. - - The list of providers should be returned via the 'apache_airflow_provider' - entrypoint as a dictionary conforming to the 'airflow/provider_info.schema.json' - schema. Note that the schema is different at runtime than provider.yaml.schema.json. - The development version of provider schema is more strict and changes together with - the code. The runtime version is more relaxed (allows for additional properties) - and verifies only the subset of fields that are needed at runtime. - """ - for entry_point, dist in entry_points_with_dist("apache_airflow_provider"): - if not dist.metadata: - continue - package_name = canonicalize_name(dist.metadata["name"]) - if package_name in self._provider_dict: - continue - log.debug("Loading %s from package %s", entry_point, package_name) - version = dist.version - provider_info = entry_point.load()() - self._provider_schema_validator.validate(provider_info) - provider_info_package_name = provider_info["package-name"] - if package_name != provider_info_package_name: - raise ValueError( - f"The package '{package_name}' from packaging information " - f"{provider_info_package_name} do not match. 
Please make sure they are aligned" - ) - - # issue-59576: Retrieve the project.urls.documentation from dist.metadata - project_urls = dist.metadata.get_all("Project-URL") - documentation_url: str | None = None - - if project_urls: - for entry in project_urls: - if "," in entry: - name, url = entry.split(",") - if name.strip().lower() == "documentation": - documentation_url = url - break - - provider_info["documentation-url"] = documentation_url - - if package_name not in self._provider_dict: - self._provider_dict[package_name] = ProviderInfo(version, provider_info) - else: - log.warning( - "The provider for package '%s' could not be registered from because providers for that " - "package name have already been registered", - package_name, - ) - def _discover_hooks_from_connection_types( self, hook_class_names_registered: set[str], diff --git a/airflow-core/tests/unit/always/test_providers_manager.py b/airflow-core/tests/unit/always/test_providers_manager.py index 1e579bfdd955d..7d7cc0507dd48 100644 --- a/airflow-core/tests/unit/always/test_providers_manager.py +++ b/airflow-core/tests/unit/always/test_providers_manager.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import json import logging import re import sys @@ -25,23 +24,19 @@ from typing import TYPE_CHECKING PY313 = sys.version_info >= (3, 13) -import warnings from unittest.mock import patch import pytest -from airflow.exceptions import AirflowOptionalProviderFeatureException from airflow.providers_manager import ( DialectInfo, - HookClassProvider, LazyDictWithCache, PluginInfo, ProviderInfo, ProvidersManager, ) -from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker, skip_if_not_on_main -from tests_common.test_utils.paths import AIRFLOW_ROOT_PATH +from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker if TYPE_CHECKING: from unittest.mock import MagicMock @@ -52,11 +47,14 @@ def test_cleanup_providers_manager(cleanup_providers_manager): """Check the cleanup provider manager functionality.""" provider_manager = ProvidersManager() - assert isinstance(provider_manager.hooks, LazyDictWithCache) - hooks = provider_manager.hooks + assert isinstance(provider_manager.providers, dict) + providers = provider_manager.providers + assert len(providers) > 0 + ProvidersManager()._cleanup() - assert not len(hooks) - assert ProvidersManager().hooks is hooks + + # even after cleanup the singleton should return same instance but internal state is reset + assert len(ProvidersManager().providers) > 0 @skip_if_force_lowest_dependencies_marker @@ -98,104 +96,6 @@ def test_providers_are_loaded(self): assert len(provider_list) > 65 assert self._caplog.records == [] - def test_hooks_deprecation_warnings_generated(self): - providers_manager = ProvidersManager() - providers_manager._provider_dict["test-package"] = ProviderInfo( - version="0.0.1", - data={"hook-class-names": ["airflow.providers.sftp.hooks.sftp.SFTPHook"]}, - ) - with pytest.warns(expected_warning=DeprecationWarning, match="hook-class-names") as warning_records: - providers_manager._discover_hooks() - assert warning_records - - def test_hooks_deprecation_warnings_not_generated(self): - with warnings.catch_warnings(record=True) as warning_records: - providers_manager = ProvidersManager() - providers_manager._provider_dict["apache-airflow-providers-sftp"] = ProviderInfo( - version="0.0.1", - data={ - "hook-class-names": ["airflow.providers.sftp.hooks.sftp.SFTPHook"], - "connection-types": [ - { - 
"hook-class-name": "airflow.providers.sftp.hooks.sftp.SFTPHook", - "connection-type": "sftp", - } - ], - }, - ) - providers_manager._discover_hooks() - assert [w.message for w in warning_records if "hook-class-names" in str(w.message)] == [] - - def test_warning_logs_generated(self): - providers_manager = ProvidersManager() - providers_manager._hooks_lazy_dict = LazyDictWithCache() - with self._caplog.at_level(logging.WARNING): - providers_manager._provider_dict["apache-airflow-providers-sftp"] = ProviderInfo( - version="0.0.1", - data={ - "hook-class-names": ["airflow.providers.sftp.hooks.sftp.SFTPHook"], - "connection-types": [ - { - "hook-class-name": "airflow.providers.sftp.hooks.sftp.SFTPHook", - "connection-type": "wrong-connection-type", - } - ], - }, - ) - providers_manager._discover_hooks() - _ = providers_manager._hooks_lazy_dict["wrong-connection-type"] - assert len(self._caplog.entries) == 1 - assert "Inconsistency!" in self._caplog[0]["event"] - assert "sftp" not in providers_manager.hooks - - def test_warning_logs_not_generated(self): - with self._caplog.at_level(logging.WARNING): - providers_manager = ProvidersManager() - providers_manager._provider_dict["apache-airflow-providers-sftp"] = ProviderInfo( - version="0.0.1", - data={ - "hook-class-names": ["airflow.providers.sftp.hooks.sftp.SFTPHook"], - "connection-types": [ - { - "hook-class-name": "airflow.providers.sftp.hooks.sftp.SFTPHook", - "connection-type": "sftp", - } - ], - }, - ) - providers_manager._discover_hooks() - _ = providers_manager._hooks_lazy_dict["sftp"] - assert not self._caplog.records - assert "sftp" in providers_manager.hooks - - def test_already_registered_conn_type_in_provide(self): - with self._caplog.at_level(logging.WARNING): - providers_manager = ProvidersManager() - providers_manager._provider_dict["apache-airflow-providers-dummy"] = ProviderInfo( - version="0.0.1", - data={ - "connection-types": [ - { - "hook-class-name": "airflow.providers.dummy.hooks.dummy.DummyHook", - "connection-type": "dummy", - }, - { - "hook-class-name": "airflow.providers.dummy.hooks.dummy.DummyHook2", - "connection-type": "dummy", - }, - ], - }, - ) - providers_manager._discover_hooks() - _ = providers_manager._hooks_lazy_dict["dummy"] - assert len(self._caplog.records) == 1 - msg = self._caplog.messages[0] - assert msg.startswith("The connection type 'dummy' is already registered") - assert ( - "different class names: 'airflow.providers.dummy.hooks.dummy.DummyHook'" - " and 'airflow.providers.dummy.hooks.dummy.DummyHook2'." - ) in msg - def test_providers_manager_register_plugins(self): providers_manager = ProvidersManager() providers_manager._provider_dict = LazyDictWithCache() @@ -243,61 +143,6 @@ def test_providers_manager_register_dialects(self): ), ) - def test_hooks(self): - with warnings.catch_warnings(record=True) as warning_records: - with self._caplog.at_level(logging.WARNING): - provider_manager = ProvidersManager() - connections_list = list(provider_manager.hooks.keys()) - assert len(connections_list) > 60 - if len(self._caplog.records) != 0: - for record in self._caplog.records: - print(record.message, file=sys.stderr) - print(record.exc_info, file=sys.stderr) - raise AssertionError("There are warnings generated during hook imports. 
Please fix them") - assert [w.message for w in warning_records if "hook-class-names" in str(w.message)] == [] - - @skip_if_not_on_main - @pytest.mark.execution_timeout(150) - def test_hook_values(self): - provider_dependencies = json.loads( - (AIRFLOW_ROOT_PATH / "generated" / "provider_dependencies.json").read_text() - ) - python_version = f"{sys.version_info.major}.{sys.version_info.minor}" - excluded_providers: list[str] = [] - for provider_name, provider_info in provider_dependencies.items(): - if python_version in provider_info.get("excluded-python-versions", []): - excluded_providers.append(f"apache-airflow-providers-{provider_name.replace('.', '-')}") - with warnings.catch_warnings(record=True) as warning_records: - with self._caplog.at_level(logging.WARNING): - provider_manager = ProvidersManager() - connections_list = list(provider_manager.hooks.values()) - assert len(connections_list) > 60 - if len(self._caplog.records) != 0: - real_warning_count = 0 - for record in self._caplog.entries: - # When there is error importing provider that is excluded the provider name is in the message - if any(excluded_provider in record["event"] for excluded_provider in excluded_providers): - continue - print(record["event"], file=sys.stderr) - print(record.get("exc_info"), file=sys.stderr) - real_warning_count += 1 - if real_warning_count: - if PY313: - only_ydb_and_yandexcloud_warnings = True - for record in warning_records: - if "ydb" in str(record.message) or "yandexcloud" in str(record.message): - continue - only_ydb_and_yandexcloud_warnings = False - if only_ydb_and_yandexcloud_warnings: - print( - "Only warnings from ydb and yandexcloud providers are generated, " - "which is expected in Python 3.13+", - file=sys.stderr, - ) - return - raise AssertionError("There are warnings generated during hook imports. 
Please fix them") - assert [w.message for w in warning_records if "hook-class-names" in str(w.message)] == [] - def test_connection_form_widgets(self): provider_manager = ProvidersManager() connections_form_widgets = list(provider_manager.connection_form_widgets.keys()) @@ -390,34 +235,6 @@ def test_dialects(self): assert len(dialect_class_names) == 3 assert dialect_class_names == ["default", "mssql", "postgresql"] - @patch("airflow.providers_manager.import_string") - def test_optional_feature_no_warning(self, mock_importlib_import_string): - with self._caplog.at_level(logging.WARNING): - mock_importlib_import_string.side_effect = AirflowOptionalProviderFeatureException() - providers_manager = ProvidersManager() - providers_manager._hook_provider_dict["test_connection"] = HookClassProvider( - package_name="test_package", hook_class_name="HookClass" - ) - providers_manager._import_hook( - hook_class_name=None, provider_info=None, package_name=None, connection_type="test_connection" - ) - assert self._caplog.messages == [] - - @patch("airflow.providers_manager.import_string") - def test_optional_feature_debug(self, mock_importlib_import_string): - with self._caplog.at_level(logging.INFO): - mock_importlib_import_string.side_effect = AirflowOptionalProviderFeatureException() - providers_manager = ProvidersManager() - providers_manager._hook_provider_dict["test_connection"] = HookClassProvider( - package_name="test_package", hook_class_name="HookClass" - ) - providers_manager._import_hook( - hook_class_name=None, provider_info=None, package_name=None, connection_type="test_connection" - ) - assert self._caplog.messages == [ - "Optional provider feature disabled when importing 'HookClass' from 'test_package' package" - ] - class TestWithoutCheckProviderManager: @patch("airflow.providers_manager.import_string") @@ -456,93 +273,3 @@ def test_executors_without_check_property_should_not_called_import_string( mock_correctness_check.assert_not_called() assert providers_manager._executor_without_check_set == result - - -@pytest.mark.parametrize( - ("value", "expected_outputs"), - [ - ("a", "a"), - (1, 1), - (None, None), - (lambda: 0, 0), - (lambda: None, None), - (lambda: "z", "z"), - ], -) -def test_lazy_cache_dict_resolving(value, expected_outputs): - lazy_cache_dict = LazyDictWithCache() - lazy_cache_dict["key"] = value - assert lazy_cache_dict["key"] == expected_outputs - # Retrieve it again to see if it is correctly returned again - assert lazy_cache_dict["key"] == expected_outputs - - -def test_lazy_cache_dict_raises_error(): - def raise_method(): - raise RuntimeError("test") - - lazy_cache_dict = LazyDictWithCache() - lazy_cache_dict["key"] = raise_method - with pytest.raises(RuntimeError, match="test"): - _ = lazy_cache_dict["key"] - - -def test_lazy_cache_dict_del_item(): - lazy_cache_dict = LazyDictWithCache() - - def answer(): - return 42 - - lazy_cache_dict["spam"] = answer - assert "spam" in lazy_cache_dict._raw_dict - assert "spam" not in lazy_cache_dict._resolved # Not resoled yet - assert lazy_cache_dict["spam"] == 42 - assert "spam" in lazy_cache_dict._resolved - del lazy_cache_dict["spam"] - assert "spam" not in lazy_cache_dict._raw_dict - assert "spam" not in lazy_cache_dict._resolved - - lazy_cache_dict["foo"] = answer - assert lazy_cache_dict["foo"] == 42 - assert "foo" in lazy_cache_dict._resolved - # Emulate some mess in data, e.g. 
value from `_raw_dict` deleted but not from `_resolved` - del lazy_cache_dict._raw_dict["foo"] - assert "foo" in lazy_cache_dict._resolved - with pytest.raises(KeyError): - # Error expected here, but we still expect to remove also record into `resolved` - del lazy_cache_dict["foo"] - assert "foo" not in lazy_cache_dict._resolved - - lazy_cache_dict["baz"] = answer - # Key in `_resolved` not created yet - assert "baz" in lazy_cache_dict._raw_dict - assert "baz" not in lazy_cache_dict._resolved - del lazy_cache_dict._raw_dict["baz"] - assert "baz" not in lazy_cache_dict._raw_dict - assert "baz" not in lazy_cache_dict._resolved - - -def test_lazy_cache_dict_clear(): - def answer(): - return 42 - - lazy_cache_dict = LazyDictWithCache() - assert len(lazy_cache_dict) == 0 - lazy_cache_dict["spam"] = answer - lazy_cache_dict["foo"] = answer - lazy_cache_dict["baz"] = answer - - assert len(lazy_cache_dict) == 3 - assert len(lazy_cache_dict._raw_dict) == 3 - assert not lazy_cache_dict._resolved - assert lazy_cache_dict["spam"] == 42 - assert len(lazy_cache_dict._resolved) == 1 - # Emulate some mess in data, contain some data into the `_resolved` - lazy_cache_dict._resolved.add("biz") - assert len(lazy_cache_dict) == 3 - assert len(lazy_cache_dict._resolved) == 2 - # And finally cleanup everything - lazy_cache_dict.clear() - assert len(lazy_cache_dict) == 0 - assert not lazy_cache_dict._raw_dict - assert not lazy_cache_dict._resolved diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 9dd6e782193f1..bc18c94f17bff 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -1890,6 +1890,17 @@ def cleanup_providers_manager(): ProvidersManager().initialize_providers_configuration() +@pytest.fixture +def cleanup_providers_manager_runtime(): + from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime + + ProvidersManagerTaskRuntime()._cleanup() + try: + yield + finally: + ProvidersManagerTaskRuntime()._cleanup() + + @pytest.fixture(autouse=True) def _disable_redact(request: pytest.FixtureRequest, mocker): """Disable redacted text in tests, except specific.""" diff --git a/providers/yandex/tests/unit/yandex/utils/test_user_agent.py b/providers/yandex/tests/unit/yandex/utils/test_user_agent.py index 0fd20f85d1187..b443b77e4537e 100644 --- a/providers/yandex/tests/unit/yandex/utils/test_user_agent.py +++ b/providers/yandex/tests/unit/yandex/utils/test_user_agent.py @@ -47,10 +47,9 @@ def test_provider_user_agent(): assert user_agent_prefix in user_agent -@mock.patch("airflow.providers_manager.ProvidersManager.hooks") -def test_provider_user_agent_hook_not_exists(mock_hooks): - mock_hooks.return_value = [] +def test_provider_user_agent_hook_not_exists(): + with mock.patch("airflow.providers_manager.ProvidersManager") as mock_pm_class: + mock_pm_class.return_value.hooks = {} - user_agent = provider_user_agent() - - assert user_agent is None + user_agent = provider_user_agent() + assert user_agent is None diff --git a/pyproject.toml b/pyproject.toml index efa0a1a2a7c3e..9d8d15bea2e0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1296,6 +1296,7 @@ dev = [ "apache-airflow-shared-module-loading", "apache-airflow-shared-observability", "apache-airflow-shared-plugins-manager", + "apache-airflow-shared-providers-discovery", "apache-airflow-shared-secrets-backend", "apache-airflow-shared-secrets-masker", "apache-airflow-shared-timezones", @@ -1354,6 +1355,7 @@ 
apache-airflow-shared-logging = { workspace = true } apache-airflow-shared-module-loading = { workspace = true } apache-airflow-shared-observability = { workspace = true } apache-airflow-shared-plugins-manager = { workspace = true } +apache-airflow-shared-providers-discovery = { workspace = true } apache-airflow-shared-secrets-backend = { workspace = true } apache-airflow-shared-secrets-masker = { workspace = true } apache-airflow-shared-timezones = { workspace = true } @@ -1481,6 +1483,7 @@ members = [ "shared/module_loading", "shared/observability", "shared/plugins_manager", + "shared/providers_discovery", "shared/secrets_backend", "shared/secrets_masker", "shared/timezones", diff --git a/shared/providers_discovery/pyproject.toml b/shared/providers_discovery/pyproject.toml new file mode 100644 index 0000000000000..1da7f4045292a --- /dev/null +++ b/shared/providers_discovery/pyproject.toml @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[project] +name = "apache-airflow-shared-providers-discovery" +description = "Shared provider discovery code for Airflow distributions" +version = "0.0" +classifiers = [ + "Private :: Do Not Upload", +] + +dependencies = [ + "packaging", + "pendulum>=3.1.0", + "jsonschema", + "structlog>=25.4.0", + "pygtrie>=2.5.0", + "methodtools>=0.4.7", + 'importlib_metadata>=6.5;python_version<"3.12"', + 'importlib_metadata>=7.0;python_version>="3.12"', +] + +[dependency-groups] +dev = [ + "apache-airflow-devel-common", + "apache-airflow-shared-module-loading", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/airflow_shared"] + +[tool.ruff] +extend = "../../pyproject.toml" +src = ["src"] + +[tool.ruff.lint.per-file-ignores] +# Ignore Doc rules et al for anything outside of tests +"!src/*" = ["D", "S101", "TRY002"] + +[tool.ruff.lint.flake8-tidy-imports] +# Override the workspace level default +ban-relative-imports = "parents" diff --git a/shared/providers_discovery/src/airflow_shared/providers_discovery/__init__.py b/shared/providers_discovery/src/airflow_shared/providers_discovery/__init__.py new file mode 100644 index 0000000000000..a7f1811043524 --- /dev/null +++ b/shared/providers_discovery/src/airflow_shared/providers_discovery/__init__.py @@ -0,0 +1,34 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from .providers_discovery import ( + KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS as KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS, + HookClassProvider as HookClassProvider, + HookInfo as HookInfo, + LazyDictWithCache as LazyDictWithCache, + PluginInfo as PluginInfo, + ProviderInfo as ProviderInfo, + _check_builtin_provider_prefix as _check_builtin_provider_prefix, + _create_provider_info_schema_validator as _create_provider_info_schema_validator, + discover_all_providers_from_packages as discover_all_providers_from_packages, + log_import_warning as log_import_warning, + log_optional_feature_disabled as log_optional_feature_disabled, + provider_info_cache as provider_info_cache, +) diff --git a/shared/providers_discovery/src/airflow_shared/providers_discovery/providers_discovery.py b/shared/providers_discovery/src/airflow_shared/providers_discovery/providers_discovery.py new file mode 100644 index 0000000000000..4fc882d1b5d23 --- /dev/null +++ b/shared/providers_discovery/src/airflow_shared/providers_discovery/providers_discovery.py @@ -0,0 +1,348 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Shared provider discovery utilities.""" + +from __future__ import annotations + +import contextlib +import json +import pathlib +from collections.abc import Callable, MutableMapping +from dataclasses import dataclass +from functools import wraps +from importlib.resources import files as resource_files +from time import perf_counter +from typing import Any, NamedTuple, ParamSpec + +import structlog +from packaging.utils import canonicalize_name + +from ..module_loading import entry_points_with_dist + +log = structlog.getLogger(__name__) + + +PS = ParamSpec("PS") + + +KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS = [("apache-airflow-providers-google", "No module named 'paramiko'")] + + +@dataclass +class ProviderInfo: + """ + Provider information. 
+ + :param version: version string + :param data: dictionary with information about the provider + """ + + version: str + data: dict + + +class HookClassProvider(NamedTuple): + """Hook class and Provider it comes from.""" + + hook_class_name: str + package_name: str + + +class HookInfo(NamedTuple): + """Hook information.""" + + hook_class_name: str + connection_id_attribute_name: str + package_name: str + hook_name: str + connection_type: str + connection_testable: bool + dialects: list[str] = [] + + +class ConnectionFormWidgetInfo(NamedTuple): + """Connection Form Widget information.""" + + hook_class_name: str + package_name: str + field: Any + field_name: str + is_sensitive: bool + + +class PluginInfo(NamedTuple): + """Plugin class, name and provider it comes from.""" + + name: str + plugin_class: str + provider_name: str + + +class NotificationInfo(NamedTuple): + """Notification class and provider it comes from.""" + + notification_class_name: str + package_name: str + + +class TriggerInfo(NamedTuple): + """Trigger class and provider it comes from.""" + + trigger_class_name: str + package_name: str + integration_name: str + + +class DialectInfo(NamedTuple): + """Dialect class and Provider it comes from.""" + + name: str + dialect_class_name: str + provider_name: str + + +class LazyDictWithCache(MutableMapping): + """ + Lazy-loaded cached dictionary. + + Dictionary, which in case you set callable, executes the passed callable with `key` attribute + at first use - and returns and caches the result. + """ + + __slots__ = ["_resolved", "_raw_dict"] + + def __init__(self, *args, **kw): + self._resolved = set() + self._raw_dict = dict(*args, **kw) + + def __setitem__(self, key, value): + self._raw_dict.__setitem__(key, value) + + def __getitem__(self, key): + value = self._raw_dict.__getitem__(key) + if key not in self._resolved and callable(value): + # exchange callable with result of calling it -- but only once! 
allow resolver to return a + # callable itself + value = value() + self._resolved.add(key) + self._raw_dict.__setitem__(key, value) + return value + + def __delitem__(self, key): + with contextlib.suppress(KeyError): + self._resolved.remove(key) + self._raw_dict.__delitem__(key) + + def __iter__(self): + return iter(self._raw_dict) + + def __len__(self): + return len(self._raw_dict) + + def __contains__(self, key): + return key in self._raw_dict + + def clear(self): + self._resolved.clear() + self._raw_dict.clear() + + +def _read_schema_from_resources_or_local_file(filename: str) -> dict: + """Read JSON schema from resources or local file.""" + try: + with resource_files("airflow").joinpath(filename).open("rb") as f: + schema = json.load(f) + except (TypeError, FileNotFoundError): + with (pathlib.Path(__file__).parent / filename).open("rb") as f: + schema = json.load(f) + return schema + + +def _create_provider_info_schema_validator(): + """Create JSON schema validator from the provider_info.schema.json.""" + import jsonschema + + schema = _read_schema_from_resources_or_local_file("provider_info.schema.json") + cls = jsonschema.validators.validator_for(schema) + validator = cls(schema) + return validator + + +def _create_customized_form_field_behaviours_schema_validator(): + """Create JSON schema validator from the customized_form_field_behaviours.schema.json.""" + import jsonschema + + schema = _read_schema_from_resources_or_local_file("customized_form_field_behaviours.schema.json") + cls = jsonschema.validators.validator_for(schema) + validator = cls(schema) + return validator + + +def _check_builtin_provider_prefix(provider_package: str, class_name: str) -> bool: + """Check if builtin provider class has correct prefix.""" + if provider_package.startswith("apache-airflow"): + provider_path = provider_package[len("apache-") :].replace("-", ".") + if not class_name.startswith(provider_path): + log.warning( + "Coherence check failed when importing '%s' from '%s' package. It should start with '%s'", + class_name, + provider_package, + provider_path, + ) + return False + return True + + +def _ensure_prefix_for_placeholders(field_behaviors: dict[str, Any], conn_type: str): + """ + Verify the correct placeholder prefix. + + If the given field_behaviors dict contains a placeholder's node, and there + are placeholders for extra fields (i.e. anything other than the built-in conn + attrs), and if those extra fields are unprefixed, then add the prefix. + + The reason we need to do this is, all custom conn fields live in the same dictionary, + so we need to namespace them with a prefix internally. But for user convenience, + and consistency between the `get_ui_field_behaviour` method and the extra dict itself, + we allow users to supply the unprefixed name. 
+ """ + conn_attrs = {"host", "schema", "login", "password", "port", "extra"} + + def ensure_prefix(field): + if field not in conn_attrs and not field.startswith("extra__"): + return f"extra__{conn_type}__{field}" + return field + + if "placeholders" in field_behaviors: + placeholders = field_behaviors["placeholders"] + field_behaviors["placeholders"] = {ensure_prefix(k): v for k, v in placeholders.items()} + + return field_behaviors + + +def log_optional_feature_disabled(class_name, e, provider_package): + """Log optional feature disabled.""" + log.debug( + "Optional feature disabled on exception when importing '%s' from '%s' package", + class_name, + provider_package, + exc_info=e, + ) + log.info( + "Optional provider feature disabled when importing '%s' from '%s' package", + class_name, + provider_package, + ) + + +def log_import_warning(class_name, e, provider_package): + """Log import warning.""" + log.warning( + "Exception when importing '%s' from '%s' package", + class_name, + provider_package, + exc_info=e, + ) + + +def provider_info_cache(cache_name: str) -> Callable[[Callable[PS, None]], Callable[PS, None]]: + """ + Decorate and cache provider info. + + Decorator factory that create decorator that caches initialization of provider's parameters + :param cache_name: Name of the cache + """ + + def provider_info_cache_decorator(func: Callable[PS, None]) -> Callable[PS, None]: + @wraps(func) + def wrapped_function(*args: PS.args, **kwargs: PS.kwargs) -> None: + instance = args[0] + + if cache_name in instance._initialized_cache: + return + start_time = perf_counter() + log.debug("Initializing Provider Manager[%s]", cache_name) + func(*args, **kwargs) + instance._initialized_cache[cache_name] = True + log.debug( + "Initialization of Provider Manager[%s] took %.2f seconds", + cache_name, + perf_counter() - start_time, + ) + + return wrapped_function + + return provider_info_cache_decorator + + +def discover_all_providers_from_packages( + provider_dict: dict[str, ProviderInfo], + provider_schema_validator, +) -> None: + """ + Discover all providers by scanning packages installed. + + The list of providers should be returned via the 'apache_airflow_provider' + entrypoint as a dictionary conforming to the 'airflow/provider_info.schema.json' + schema. Note that the schema is different at runtime than provider.yaml.schema.json. + The development version of provider schema is more strict and changes together with + the code. The runtime version is more relaxed (allows for additional properties) + and verifies only the subset of fields that are needed at runtime. + + :param provider_dict: Dictionary to populate with discovered providers + :param provider_schema_validator: JSON schema validator for provider info + """ + for entry_point, dist in entry_points_with_dist("apache_airflow_provider"): + if not dist.metadata: + continue + package_name = canonicalize_name(dist.metadata["name"]) + if package_name in provider_dict: + continue + log.debug("Loading %s from package %s", entry_point, package_name) + version = dist.version + provider_info = entry_point.load()() + provider_schema_validator.validate(provider_info) + provider_info_package_name = provider_info["package-name"] + if package_name != provider_info_package_name: + raise ValueError( + f"The package '{package_name}' from packaging information " + f"{provider_info_package_name} do not match. 
Please make sure they are aligned" + ) + + # issue-59576: Retrieve the project.urls.documentation from dist.metadata + project_urls = dist.metadata.get_all("Project-URL") + documentation_url: str | None = None + + if project_urls: + for entry in project_urls: + if "," in entry: + name, url = entry.split(",") + if name.strip().lower() == "documentation": + documentation_url = url + break + + provider_info["documentation-url"] = documentation_url + + if package_name not in provider_dict: + provider_dict[package_name] = ProviderInfo(version, provider_info) + else: + log.warning( + "The provider for package '%s' could not be registered from because providers for that " + "package name have already been registered", + package_name, + ) diff --git a/shared/providers_discovery/tests/providers_discovery/test_providers_discovery.py b/shared/providers_discovery/tests/providers_discovery/test_providers_discovery.py new file mode 100644 index 0000000000000..87f97a3e5a2f2 --- /dev/null +++ b/shared/providers_discovery/tests/providers_discovery/test_providers_discovery.py @@ -0,0 +1,112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow_shared.providers_discovery import LazyDictWithCache + + +@pytest.mark.parametrize( + ("value", "expected_outputs"), + [ + ("a", "a"), + (1, 1), + (None, None), + (lambda: 0, 0), + (lambda: None, None), + (lambda: "z", "z"), + ], +) +def test_lazy_cache_dict_resolving(value, expected_outputs): + lazy_cache_dict = LazyDictWithCache() + lazy_cache_dict["key"] = value + assert lazy_cache_dict["key"] == expected_outputs + # Retrieve it again to see if it is correctly returned again + assert lazy_cache_dict["key"] == expected_outputs + + +def test_lazy_cache_dict_raises_error(): + def raise_method(): + raise RuntimeError("test") + + lazy_cache_dict = LazyDictWithCache() + lazy_cache_dict["key"] = raise_method + with pytest.raises(RuntimeError, match="test"): + _ = lazy_cache_dict["key"] + + +def test_lazy_cache_dict_del_item(): + lazy_cache_dict = LazyDictWithCache() + + def answer(): + return 42 + + lazy_cache_dict["spam"] = answer + assert "spam" in lazy_cache_dict._raw_dict + assert "spam" not in lazy_cache_dict._resolved # Not resoled yet + assert lazy_cache_dict["spam"] == 42 + assert "spam" in lazy_cache_dict._resolved + del lazy_cache_dict["spam"] + assert "spam" not in lazy_cache_dict._raw_dict + assert "spam" not in lazy_cache_dict._resolved + + lazy_cache_dict["foo"] = answer + assert lazy_cache_dict["foo"] == 42 + assert "foo" in lazy_cache_dict._resolved + # Emulate some mess in data, e.g. 
value from `_raw_dict` deleted but not from `_resolved` + del lazy_cache_dict._raw_dict["foo"] + assert "foo" in lazy_cache_dict._resolved + with pytest.raises(KeyError): + # Error expected here, but we still expect to remove also record into `resolved` + del lazy_cache_dict["foo"] + assert "foo" not in lazy_cache_dict._resolved + + lazy_cache_dict["baz"] = answer + # Key in `_resolved` not created yet + assert "baz" in lazy_cache_dict._raw_dict + assert "baz" not in lazy_cache_dict._resolved + del lazy_cache_dict._raw_dict["baz"] + assert "baz" not in lazy_cache_dict._raw_dict + assert "baz" not in lazy_cache_dict._resolved + + +def test_lazy_cache_dict_clear(): + def answer(): + return 42 + + lazy_cache_dict = LazyDictWithCache() + assert len(lazy_cache_dict) == 0 + lazy_cache_dict["spam"] = answer + lazy_cache_dict["foo"] = answer + lazy_cache_dict["baz"] = answer + + assert len(lazy_cache_dict) == 3 + assert len(lazy_cache_dict._raw_dict) == 3 + assert not lazy_cache_dict._resolved + assert lazy_cache_dict["spam"] == 42 + assert len(lazy_cache_dict._resolved) == 1 + # Emulate some mess in data, contain some data into the `_resolved` + lazy_cache_dict._resolved.add("biz") + assert len(lazy_cache_dict) == 3 + assert len(lazy_cache_dict._resolved) == 2 + # And finally cleanup everything + lazy_cache_dict.clear() + assert len(lazy_cache_dict) == 0 + assert not lazy_cache_dict._raw_dict + assert not lazy_cache_dict._resolved diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml index fc989724391d8..350f901ee17bc 100644 --- a/task-sdk/pyproject.toml +++ b/task-sdk/pyproject.toml @@ -84,6 +84,9 @@ dependencies = [ 'importlib_metadata>=6.5;python_version<"3.12"', "pathspec>=0.9.0", # End of shared module-loading dependencies + # Start of shared providers-discovery dependencies + "jsonschema", + # End of shared providers-discovery dependencies ] [project.optional-dependencies] @@ -132,6 +135,7 @@ path = "src/airflow/sdk/__init__.py" "../shared/timezones/src/airflow_shared/timezones" = "src/airflow/sdk/_shared/timezones" "../shared/listeners/src/airflow_shared/listeners" = "src/airflow/sdk/_shared/listeners" "../shared/plugins_manager/src/airflow_shared/plugins_manager" = "src/airflow/sdk/_shared/plugins_manager" +"../shared/providers_discovery/src/airflow_shared/providers_discovery" = "src/airflow/sdk/_shared/providers_discovery" [tool.hatch.build.targets.wheel] packages = ["src/airflow"] @@ -283,4 +287,5 @@ shared_distributions = [ "apache-airflow-shared-timezones", "apache-airflow-shared-observability", "apache-airflow-shared-plugins-manager", + "apache-airflow-shared-providers-discovery", ] diff --git a/task-sdk/src/airflow/sdk/_shared/providers_discovery b/task-sdk/src/airflow/sdk/_shared/providers_discovery new file mode 120000 index 0000000000000..b66ada0d22bbd --- /dev/null +++ b/task-sdk/src/airflow/sdk/_shared/providers_discovery @@ -0,0 +1 @@ +../../../../../shared/providers_discovery/src/airflow_shared/providers_discovery \ No newline at end of file diff --git a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py index c4fdccc8e1be4..29eabafaf04e2 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -27,6 +27,8 @@ import attrs +from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime + if TYPE_CHECKING: from collections.abc import Collection from urllib.parse import SplitResult @@ -128,9 +130,8 @@ def 
normalize_noop(parts: SplitResult) -> SplitResult: def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | None: if scheme == "file": return normalize_noop - from airflow.providers_manager import ProvidersManager - return ProvidersManager().asset_uri_handlers.get(scheme) + return ProvidersManagerTaskRuntime().asset_uri_handlers.get(scheme) def _get_normalized_scheme(uri: str) -> str: diff --git a/task-sdk/src/airflow/sdk/definitions/connection.py b/task-sdk/src/airflow/sdk/definitions/connection.py index bcf7937f03442..9ea239d0f0f97 100644 --- a/task-sdk/src/airflow/sdk/definitions/connection.py +++ b/task-sdk/src/airflow/sdk/definitions/connection.py @@ -26,6 +26,7 @@ import attrs from airflow.sdk.exceptions import AirflowException, AirflowNotFoundException, AirflowRuntimeError, ErrorType +from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime log = logging.getLogger(__name__) @@ -188,10 +189,9 @@ def get_uri(self) -> str: def get_hook(self, *, hook_params=None): """Return hook based on conn_type.""" - from airflow.providers_manager import ProvidersManager from airflow.sdk._shared.module_loading import import_string - hook = ProvidersManager().hooks.get(self.conn_type, None) + hook = ProvidersManagerTaskRuntime().hooks.get(self.conn_type, None) if hook is None: raise AirflowException(f'Unknown hook type "{self.conn_type}"') diff --git a/task-sdk/src/airflow/sdk/definitions/decorators/__init__.py b/task-sdk/src/airflow/sdk/definitions/decorators/__init__.py index 41a7c2d0bf290..7a4d3125e5fd8 100644 --- a/task-sdk/src/airflow/sdk/definitions/decorators/__init__.py +++ b/task-sdk/src/airflow/sdk/definitions/decorators/__init__.py @@ -18,12 +18,12 @@ from collections.abc import Callable -from airflow.providers_manager import ProvidersManager from airflow.sdk.bases.decorator import TaskDecorator from airflow.sdk.definitions.dag import dag from airflow.sdk.definitions.decorators.condition import run_if, skip_if from airflow.sdk.definitions.decorators.setup_teardown import setup_task, teardown_task from airflow.sdk.definitions.decorators.task_group import task_group +from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime # Please keep this in sync with the .pyi's __all__. __all__ = [ @@ -47,7 +47,7 @@ def __getattr__(self, name: str) -> TaskDecorator: """Dynamically get provider-registered task decorators, e.g. 
``@task.docker``.""" if name.startswith("__"): raise AttributeError(f"{type(self).__name__} has no attribute {name!r}") - decorators = ProvidersManager().taskflow_decorators + decorators = ProvidersManagerTaskRuntime().taskflow_decorators if name not in decorators: raise AttributeError(f"task decorator {name!r} not found") return decorators[name] diff --git a/task-sdk/src/airflow/sdk/io/fs.py b/task-sdk/src/airflow/sdk/io/fs.py index 524a6f767b27b..b51be36d48aa0 100644 --- a/task-sdk/src/airflow/sdk/io/fs.py +++ b/task-sdk/src/airflow/sdk/io/fs.py @@ -24,9 +24,9 @@ from fsspec.implementations.local import LocalFileSystem -from airflow.providers_manager import ProvidersManager from airflow.sdk._shared.module_loading import import_string from airflow.sdk._shared.observability.metrics.stats import Stats +from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime if TYPE_CHECKING: from fsspec import AbstractFileSystem @@ -55,7 +55,7 @@ def _register_filesystems() -> Mapping[ ]: scheme_to_fs = _BUILTIN_SCHEME_TO_FS.copy() with Stats.timer("airflow.io.load_filesystems") as timer: - manager = ProvidersManager() + manager = ProvidersManagerTaskRuntime() for fs_module_name in manager.filesystem_module_names: fs_module = import_string(fs_module_name) for scheme in getattr(fs_module, "schemes", []): diff --git a/task-sdk/src/airflow/sdk/plugins_manager.py b/task-sdk/src/airflow/sdk/plugins_manager.py index 603c0d23d0b4c..bdb9fd9ec0e8d 100644 --- a/task-sdk/src/airflow/sdk/plugins_manager.py +++ b/task-sdk/src/airflow/sdk/plugins_manager.py @@ -24,7 +24,6 @@ from typing import TYPE_CHECKING from airflow import settings -from airflow.providers_manager import ProvidersManager from airflow.sdk._shared.module_loading import import_string from airflow.sdk._shared.observability.metrics.stats import Stats from airflow.sdk._shared.plugins_manager import ( @@ -36,6 +35,7 @@ is_valid_plugin, ) from airflow.sdk.configuration import conf +from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime if TYPE_CHECKING: from airflow.listeners.listener import ListenerManager @@ -46,7 +46,7 @@ def _load_providers_plugins() -> tuple[list[AirflowPlugin], dict[str, str]]: """Load plugins from providers.""" log.debug("Loading plugins from providers") - providers_manager = ProvidersManager() + providers_manager = ProvidersManagerTaskRuntime() providers_manager.initialize_providers_plugins() plugins: list[AirflowPlugin] = [] diff --git a/task-sdk/src/airflow/sdk/providers_manager_runtime.py b/task-sdk/src/airflow/sdk/providers_manager_runtime.py new file mode 100644 index 0000000000000..e7b1b65e2bf12 --- /dev/null +++ b/task-sdk/src/airflow/sdk/providers_manager_runtime.py @@ -0,0 +1,613 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Manages runtime provider resources for task execution.""" + +from __future__ import annotations + +import functools +import inspect +import traceback +import warnings +from collections.abc import Callable, MutableMapping +from typing import TYPE_CHECKING, Any +from urllib.parse import SplitResult + +import structlog + +from airflow.sdk._shared.module_loading import import_string +from airflow.sdk._shared.providers_discovery import ( + KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS, + HookClassProvider, + HookInfo, + LazyDictWithCache, + PluginInfo, + ProviderInfo, + _check_builtin_provider_prefix, + _create_provider_info_schema_validator, + discover_all_providers_from_packages, + log_import_warning, + log_optional_feature_disabled, + provider_info_cache, +) +from airflow.sdk.definitions._internal.logging_mixin import LoggingMixin +from airflow.sdk.exceptions import AirflowOptionalProviderFeatureException + +if TYPE_CHECKING: + from airflow.sdk import BaseHook + from airflow.sdk.bases.decorator import TaskDecorator + from airflow.sdk.definitions.asset import Asset + +log = structlog.getLogger(__name__) + + +def _correctness_check(provider_package: str, class_name: str, provider_info: ProviderInfo) -> Any: + """ + Perform coherence check on provider classes. + + For apache-airflow providers - it checks if it starts with appropriate package. For all providers + it tries to import the provider - checking that there are no exceptions during importing. + It logs appropriate warning in case it detects any problems. + + :param provider_package: name of the provider package + :param class_name: name of the class to import + + :return the class if the class is OK, None otherwise. + """ + if not _check_builtin_provider_prefix(provider_package, class_name): + return None + try: + imported_class = import_string(class_name) + except AirflowOptionalProviderFeatureException as e: + # When the provider class raises AirflowOptionalProviderFeatureException + # this is an expected case when only some classes in provider are + # available. We just log debug level here and print info message in logs so that + # the user is aware of it + log_optional_feature_disabled(class_name, e, provider_package) + return None + except ImportError as e: + if "No module named 'airflow.providers." in e.msg: + # handle cases where another provider is missing. This can only happen if + # there is an optional feature, so we log debug and print information about it + log_optional_feature_disabled(class_name, e, provider_package) + return None + for known_error in KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS: + # Until we convert all providers to use AirflowOptionalProviderFeatureException + # we assume any problem with importing another "provider" is because this is an + # optional feature, so we log debug and print information about it + if known_error[0] == provider_package and known_error[1] in e.msg: + log_optional_feature_disabled(class_name, e, provider_package) + return None + # But when we have no idea - we print warning to logs + log_import_warning(class_name, e, provider_package) + return None + except Exception as e: + log_import_warning(class_name, e, provider_package) + return None + return imported_class + + +class ProvidersManagerTaskRuntime(LoggingMixin): + """ + Manages runtime provider resources for task execution. + + This is a Singleton class. The first time it is instantiated, it discovers all available + runtime provider resources (hooks, taskflow decorators, filesystems, asset handlers). 
+ """ + + resource_version = "0" + _initialized: bool = False + _initialization_stack_trace = None + _instance: ProvidersManagerTaskRuntime | None = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + @staticmethod + def initialized() -> bool: + return ProvidersManagerTaskRuntime._initialized + + @staticmethod + def initialization_stack_trace() -> str | None: + return ProvidersManagerTaskRuntime._initialization_stack_trace + + def __init__(self): + """Initialize the runtime manager.""" + # skip initialization if already initialized + if self.initialized(): + return + super().__init__() + ProvidersManagerTaskRuntime._initialized = True + ProvidersManagerTaskRuntime._initialization_stack_trace = "".join( + traceback.format_stack(inspect.currentframe()) + ) + self._initialized_cache: dict[str, bool] = {} + # Keeps dict of providers keyed by module name + self._provider_dict: dict[str, ProviderInfo] = {} + self._fs_set: set[str] = set() + self._asset_uri_handlers: dict[str, Callable[[SplitResult], SplitResult]] = {} + self._asset_factories: dict[str, Callable[..., Asset]] = {} + self._asset_to_openlineage_converters: dict[str, Callable] = {} + self._taskflow_decorators: dict[str, Callable] = LazyDictWithCache() + # keeps mapping between connection_types and hook class, package they come from + self._hook_provider_dict: dict[str, HookClassProvider] = {} + # Keeps dict of hooks keyed by connection type. They are lazy evaluated at access time + self._hooks_lazy_dict: LazyDictWithCache[str, HookInfo | Callable] = LazyDictWithCache() + self._plugins_set: set[PluginInfo] = set() + self._provider_schema_validator = _create_provider_info_schema_validator() + self._init_airflow_core_hooks() + + def _init_airflow_core_hooks(self): + """Initialize the hooks dict with default hooks from Airflow core.""" + core_dummy_hooks = { + "generic": "Generic", + "email": "Email", + } + for key, display in core_dummy_hooks.items(): + self._hooks_lazy_dict[key] = HookInfo( + hook_class_name=None, + connection_id_attribute_name=None, + package_name=None, + hook_name=display, + connection_type=None, + connection_testable=False, + ) + for conn_type, class_name in ( + ("fs", "airflow.providers.standard.hooks.filesystem.FSHook"), + ("package_index", "airflow.providers.standard.hooks.package_index.PackageIndexHook"), + ): + self._hooks_lazy_dict[conn_type] = functools.partial( + self._import_hook, + connection_type=None, + package_name="apache-airflow-providers-standard", + hook_class_name=class_name, + provider_info=None, + ) + + @provider_info_cache("list") + def initialize_providers_list(self): + """Lazy initialization of providers list.""" + discover_all_providers_from_packages(self._provider_dict, self._provider_schema_validator) + self._provider_dict = dict(sorted(self._provider_dict.items())) + + @provider_info_cache("hooks") + def initialize_providers_hooks(self): + """Lazy initialization of providers hooks.""" + self._init_airflow_core_hooks() + self.initialize_providers_list() + self._discover_hooks() + self._hook_provider_dict = dict(sorted(self._hook_provider_dict.items())) + + @provider_info_cache("filesystems") + def initialize_providers_filesystems(self): + """Lazy initialization of providers filesystems.""" + self.initialize_providers_list() + self._discover_filesystems() + + @provider_info_cache("asset_uris") + def initialize_providers_asset_uri_resources(self): + """Lazy initialization of provider asset URI handlers, factories, converters 
etc.""" + self.initialize_providers_list() + self._discover_asset_uri_resources() + + @provider_info_cache("plugins") + def initialize_providers_plugins(self): + """Lazy initialization of providers plugins.""" + self.initialize_providers_list() + self._discover_plugins() + + @provider_info_cache("taskflow_decorators") + def initialize_providers_taskflow_decorator(self): + """Lazy initialization of providers taskflow decorators.""" + self.initialize_providers_list() + self._discover_taskflow_decorators() + + def _discover_hooks_from_connection_types( + self, + hook_class_names_registered: set[str], + already_registered_warning_connection_types: set[str], + package_name: str, + provider: ProviderInfo, + ): + """ + Discover hooks from the "connection-types" property. + + This is new, better method that replaces discovery from hook-class-names as it + allows to lazy import individual Hook classes when they are accessed. + The "connection-types" keeps information about both - connection type and class + name so we can discover all connection-types without importing the classes. + :param hook_class_names_registered: set of registered hook class names for this provider + :param already_registered_warning_connection_types: set of connections for which warning should be + printed in logs as they were already registered before + :param package_name: + :param provider: + :return: + """ + provider_uses_connection_types = False + connection_types = provider.data.get("connection-types") + if connection_types: + for connection_type_dict in connection_types: + connection_type = connection_type_dict["connection-type"] + hook_class_name = connection_type_dict["hook-class-name"] + hook_class_names_registered.add(hook_class_name) + already_registered = self._hook_provider_dict.get(connection_type) + if already_registered: + if already_registered.package_name != package_name: + already_registered_warning_connection_types.add(connection_type) + else: + log.warning( + "The connection type '%s' is already registered in the" + " package '%s' with different class names: '%s' and '%s'. ", + connection_type, + package_name, + already_registered.hook_class_name, + hook_class_name, + ) + else: + self._hook_provider_dict[connection_type] = HookClassProvider( + hook_class_name=hook_class_name, package_name=package_name + ) + # Defer importing hook to access time by setting import hook method as dict value + self._hooks_lazy_dict[connection_type] = functools.partial( + self._import_hook, + connection_type=connection_type, + provider_info=provider, + ) + provider_uses_connection_types = True + return provider_uses_connection_types + + def _discover_hooks_from_hook_class_names( + self, + hook_class_names_registered: set[str], + already_registered_warning_connection_types: set[str], + package_name: str, + provider: ProviderInfo, + provider_uses_connection_types: bool, + ): + """ + Discover hooks from "hook-class-names' property. + + This property is deprecated but we should support it in Airflow 2. + The hook-class-names array contained just Hook names without connection type, + therefore we need to import all those classes immediately to know which connection types + are supported. This makes it impossible to selectively only import those hooks that are used. 
+ :param already_registered_warning_connection_types: list of connection hooks that we should warn + about when finished discovery + :param package_name: name of the provider package + :param provider: class that keeps information about version and details of the provider + :param provider_uses_connection_types: determines whether the provider uses "connection-types" new + form of passing connection types + :return: + """ + hook_class_names = provider.data.get("hook-class-names") + if hook_class_names: + for hook_class_name in hook_class_names: + if hook_class_name in hook_class_names_registered: + # Silently ignore the hook class - it's already marked for lazy-import by + # connection-types discovery + continue + hook_info = self._import_hook( + connection_type=None, + provider_info=provider, + hook_class_name=hook_class_name, + package_name=package_name, + ) + if not hook_info: + # Problem why importing class - we ignore it. Log is written at import time + continue + already_registered = self._hook_provider_dict.get(hook_info.connection_type) + if already_registered: + if already_registered.package_name != package_name: + already_registered_warning_connection_types.add(hook_info.connection_type) + else: + if already_registered.hook_class_name != hook_class_name: + log.warning( + "The hook connection type '%s' is registered twice in the" + " package '%s' with different class names: '%s' and '%s'. " + " Please fix it!", + hook_info.connection_type, + package_name, + already_registered.hook_class_name, + hook_class_name, + ) + else: + self._hook_provider_dict[hook_info.connection_type] = HookClassProvider( + hook_class_name=hook_class_name, package_name=package_name + ) + self._hooks_lazy_dict[hook_info.connection_type] = hook_info + + if not provider_uses_connection_types: + warnings.warn( + f"The provider {package_name} uses `hook-class-names` " + "property in provider-info and has no `connection-types` one. " + "The 'hook-class-names' property has been deprecated in favour " + "of 'connection-types' in Airflow 2.2. 
Use **both** in case you want to " + "have backwards compatibility with Airflow < 2.2", + DeprecationWarning, + stacklevel=1, + ) + for already_registered_connection_type in already_registered_warning_connection_types: + log.warning( + "The connection_type '%s' has been already registered by provider '%s.'", + already_registered_connection_type, + self._hook_provider_dict[already_registered_connection_type].package_name, + ) + + def _discover_hooks(self) -> None: + """Retrieve all connections defined in the providers via Hooks.""" + for package_name, provider in self._provider_dict.items(): + duplicated_connection_types: set[str] = set() + hook_class_names_registered: set[str] = set() + provider_uses_connection_types = self._discover_hooks_from_connection_types( + hook_class_names_registered, duplicated_connection_types, package_name, provider + ) + self._discover_hooks_from_hook_class_names( + hook_class_names_registered, + duplicated_connection_types, + package_name, + provider, + provider_uses_connection_types, + ) + self._hook_provider_dict = dict(sorted(self._hook_provider_dict.items())) + + @staticmethod + def _get_attr(obj: Any, attr_name: str): + """Retrieve attributes of an object, or warn if not found.""" + if not hasattr(obj, attr_name): + log.warning("The object '%s' is missing %s attribute and cannot be registered", obj, attr_name) + return None + return getattr(obj, attr_name) + + def _import_hook( + self, + connection_type: str | None, + provider_info: ProviderInfo, + hook_class_name: str | None = None, + package_name: str | None = None, + ) -> HookInfo | None: + """ + Import hook and retrieve hook information. + + Either connection_type (for lazy loading) or hook_class_name must be set - but not both). + Only needs package_name if hook_class_name is passed (for lazy loading, package_name + is retrieved from _connection_type_class_provider_dict together with hook_class_name). + + :param connection_type: type of the connection + :param hook_class_name: name of the hook class + :param package_name: provider package - only needed in case connection_type is missing + : return + """ + if connection_type is None and hook_class_name is None: + raise ValueError("Either connection_type or hook_class_name must be set") + if connection_type is not None and hook_class_name is not None: + raise ValueError( + f"Both connection_type ({connection_type} and " + f"hook_class_name {hook_class_name} are set. Only one should be set!" + ) + if connection_type is not None: + class_provider = self._hook_provider_dict[connection_type] + package_name = class_provider.package_name + hook_class_name = class_provider.hook_class_name + else: + if not hook_class_name: + raise ValueError("Either connection_type or hook_class_name must be set") + if not package_name: + raise ValueError( + f"Provider package name is not set when hook_class_name ({hook_class_name}) is used" + ) + hook_class: type[BaseHook] | None = _correctness_check(package_name, hook_class_name, provider_info) + if hook_class is None: + return None + + hook_connection_type = self._get_attr(hook_class, "conn_type") + if connection_type: + if hook_connection_type != connection_type: + log.warning( + "Inconsistency! The hook class '%s' declares connection type '%s'" + " but it is added by provider '%s' as connection_type '%s' in provider info. 
" + "This should be fixed!", + hook_class, + hook_connection_type, + package_name, + connection_type, + ) + connection_type = hook_connection_type + connection_id_attribute_name: str = self._get_attr(hook_class, "conn_name_attr") + hook_name: str = self._get_attr(hook_class, "hook_name") + + if not connection_type or not connection_id_attribute_name or not hook_name: + log.warning( + "The hook misses one of the key attributes: " + "conn_type: %s, conn_id_attribute_name: %s, hook_name: %s", + connection_type, + connection_id_attribute_name, + hook_name, + ) + return None + + return HookInfo( + hook_class_name=hook_class_name, + connection_id_attribute_name=connection_id_attribute_name, + package_name=package_name, + hook_name=hook_name, + connection_type=connection_type, + connection_testable=hasattr(hook_class, "test_connection"), + ) + + def _discover_filesystems(self) -> None: + """Retrieve all filesystems defined in the providers.""" + for provider_package, provider in self._provider_dict.items(): + for fs_module_name in provider.data.get("filesystems", []): + if _correctness_check(provider_package, f"{fs_module_name}.get_fs", provider): + self._fs_set.add(fs_module_name) + self._fs_set = set(sorted(self._fs_set)) + + def _discover_asset_uri_resources(self) -> None: + """Discovers and registers asset URI handlers, factories, and converters for all providers.""" + from airflow.sdk.definitions.asset import normalize_noop + + def _safe_register_resource( + provider_package_name: str, + schemes_list: list[str], + resource_path: str | None, + resource_registry: dict, + default_resource: Any = None, + ): + """ + Register a specific resource (handler, factory, or converter) for the given schemes. + + If the resolved resource (either from the path or the default) is valid, it updates + the resource registry with the appropriate resource for each scheme. 
+ """ + resource = ( + _correctness_check(provider_package_name, resource_path, provider) + if resource_path is not None + else default_resource + ) + if resource: + resource_registry.update((scheme, resource) for scheme in schemes_list) + + for provider_name, provider in self._provider_dict.items(): + for uri_info in provider.data.get("asset-uris", []): + if "schemes" not in uri_info or "handler" not in uri_info: + continue # Both schemas and handler must be explicitly set, handler can be set to null + common_args = {"schemes_list": uri_info["schemes"], "provider_package_name": provider_name} + _safe_register_resource( + resource_path=uri_info["handler"], + resource_registry=self._asset_uri_handlers, + default_resource=normalize_noop, + **common_args, + ) + _safe_register_resource( + resource_path=uri_info.get("factory"), + resource_registry=self._asset_factories, + **common_args, + ) + _safe_register_resource( + resource_path=uri_info.get("to_openlineage_converter"), + resource_registry=self._asset_to_openlineage_converters, + **common_args, + ) + + def _discover_plugins(self) -> None: + """Retrieve all plugins defined in the providers.""" + for provider_package, provider in self._provider_dict.items(): + for plugin_dict in provider.data.get("plugins", ()): + if not _correctness_check(provider_package, plugin_dict["plugin-class"], provider): + log.warning("Plugin not loaded due to above correctness check problem.") + continue + self._plugins_set.add( + PluginInfo( + name=plugin_dict["name"], + plugin_class=plugin_dict["plugin-class"], + provider_name=provider_package, + ) + ) + + def _discover_taskflow_decorators(self) -> None: + for name, info in self._provider_dict.items(): + for taskflow_decorator in info.data.get("task-decorators", []): + self._add_taskflow_decorator( + taskflow_decorator["name"], taskflow_decorator["class-name"], name + ) + + def _add_taskflow_decorator(self, name, decorator_class_name: str, provider_package: str) -> None: + if not _check_builtin_provider_prefix(provider_package, decorator_class_name): + return + + if name in self._taskflow_decorators: + try: + existing = self._taskflow_decorators[name] + other_name = f"{existing.__module__}.{existing.__name__}" + except Exception: + # If problem importing, then get the value from the functools.partial + other_name = self._taskflow_decorators._raw_dict[name].args[0] # type: ignore[attr-defined] + + log.warning( + "The taskflow decorator '%s' has been already registered (by %s).", + name, + other_name, + ) + return + + self._taskflow_decorators[name] = functools.partial(import_string, decorator_class_name) + + @property + def providers(self) -> dict[str, ProviderInfo]: + """Returns information about available providers.""" + self.initialize_providers_list() + return self._provider_dict + + @property + def hooks(self) -> MutableMapping[str, HookInfo | None]: + """ + Return dictionary of connection_type-to-hook mapping. + + Note that the dict can contain None values if a hook discovered cannot be imported! 
+ """ + self.initialize_providers_hooks() + return self._hooks_lazy_dict + + @property + def taskflow_decorators(self) -> dict[str, TaskDecorator]: + self.initialize_providers_taskflow_decorator() + return self._taskflow_decorators # type: ignore[return-value] + + @property + def filesystem_module_names(self) -> list[str]: + self.initialize_providers_filesystems() + return sorted(self._fs_set) + + @property + def asset_factories(self) -> dict[str, Callable[..., Asset]]: + self.initialize_providers_asset_uri_resources() + return self._asset_factories + + @property + def asset_uri_handlers(self) -> dict[str, Callable[[SplitResult], SplitResult]]: + self.initialize_providers_asset_uri_resources() + return self._asset_uri_handlers + + @property + def asset_to_openlineage_converters( + self, + ) -> dict[str, Callable]: + self.initialize_providers_asset_uri_resources() + return self._asset_to_openlineage_converters + + @property + def plugins(self) -> list[PluginInfo]: + """Returns information about plugins available in providers.""" + self.initialize_providers_plugins() + return sorted(self._plugins_set, key=lambda x: x.plugin_class) + + def _cleanup(self): + self._initialized_cache.clear() + self._provider_dict.clear() + self._fs_set.clear() + self._taskflow_decorators.clear() + self._hook_provider_dict.clear() + self._hooks_lazy_dict.clear() + self._plugins_set.clear() + self._asset_uri_handlers.clear() + self._asset_factories.clear() + self._asset_to_openlineage_converters.clear() + + self._initialized = False + self._initialization_stack_trace = None diff --git a/task-sdk/tests/task_sdk/definitions/test_connection.py b/task-sdk/tests/task_sdk/definitions/test_connection.py index 8fca258ec7936..d7811f491c962 100644 --- a/task-sdk/tests/task_sdk/definitions/test_connection.py +++ b/task-sdk/tests/task_sdk/definitions/test_connection.py @@ -36,7 +36,7 @@ class TestConnections: @pytest.fixture def mock_providers_manager(self): """Mock the ProvidersManager to return predefined hooks.""" - with mock.patch("airflow.providers_manager.ProvidersManager") as mock_manager: + with mock.patch("airflow.sdk.definitions.connection.ProvidersManagerTaskRuntime") as mock_manager: yield mock_manager @mock.patch("airflow.sdk._shared.module_loading.import_string") diff --git a/task-sdk/tests/task_sdk/docs/test_public_api.py b/task-sdk/tests/task_sdk/docs/test_public_api.py index e7f653d76a263..2186cd32b1243 100644 --- a/task-sdk/tests/task_sdk/docs/test_public_api.py +++ b/task-sdk/tests/task_sdk/docs/test_public_api.py @@ -61,6 +61,7 @@ def test_airflow_sdk_no_unexpected_exports(): "observability", "plugins_manager", "listener", + "providers_manager_runtime", } unexpected = actual - public - ignore assert not unexpected, f"Unexpected exports in airflow.sdk: {sorted(unexpected)}" diff --git a/task-sdk/tests/task_sdk/test_providers_manager_runtime.py b/task-sdk/tests/task_sdk/test_providers_manager_runtime.py new file mode 100644 index 0000000000000..da6600a6fda5d --- /dev/null +++ b/task-sdk/tests/task_sdk/test_providers_manager_runtime.py @@ -0,0 +1,238 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +import logging +import sys +import warnings +from unittest.mock import patch + +import pytest + +from airflow.exceptions import AirflowOptionalProviderFeatureException +from airflow.sdk._shared.providers_discovery import ( + HookClassProvider, + LazyDictWithCache, + ProviderInfo, +) +from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime + +from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker, skip_if_not_on_main +from tests_common.test_utils.paths import AIRFLOW_ROOT_PATH + +PY313 = sys.version_info >= (3, 13) + + +def test_cleanup_providers_manager_runtime(cleanup_providers_manager): + """Check the cleanup provider manager functionality.""" + provider_manager = ProvidersManagerTaskRuntime() + # Check by type name since symlinks create different module paths + assert type(provider_manager.hooks).__name__ == "LazyDictWithCache" + hooks = provider_manager.hooks + ProvidersManagerTaskRuntime()._cleanup() + assert not len(hooks) + assert ProvidersManagerTaskRuntime().hooks is hooks + + +@skip_if_force_lowest_dependencies_marker +class TestProvidersManagerRuntime: + @pytest.fixture(autouse=True) + def inject_fixtures(self, caplog, cleanup_providers_manager_runtime): + self._caplog = caplog + + def test_hooks_deprecation_warnings_generated(self): + providers_manager = ProvidersManagerTaskRuntime() + providers_manager._provider_dict["test-package"] = ProviderInfo( + version="0.0.1", + data={"hook-class-names": ["airflow.providers.sftp.hooks.sftp.SFTPHook"]}, + ) + with pytest.warns(expected_warning=DeprecationWarning, match="hook-class-names") as warning_records: + providers_manager._discover_hooks() + assert warning_records + + def test_hooks_deprecation_warnings_not_generated(self): + with warnings.catch_warnings(record=True) as warning_records: + providers_manager = ProvidersManagerTaskRuntime() + providers_manager._provider_dict["apache-airflow-providers-sftp"] = ProviderInfo( + version="0.0.1", + data={ + "hook-class-names": ["airflow.providers.sftp.hooks.sftp.SFTPHook"], + "connection-types": [ + { + "hook-class-name": "airflow.providers.sftp.hooks.sftp.SFTPHook", + "connection-type": "sftp", + } + ], + }, + ) + providers_manager._discover_hooks() + assert [w.message for w in warning_records if "hook-class-names" in str(w.message)] == [] + + def test_warning_logs_generated(self): + providers_manager = ProvidersManagerTaskRuntime() + providers_manager._hooks_lazy_dict = LazyDictWithCache() + with self._caplog.at_level(logging.WARNING): + providers_manager._provider_dict["apache-airflow-providers-sftp"] = ProviderInfo( + version="0.0.1", + data={ + "hook-class-names": ["airflow.providers.sftp.hooks.sftp.SFTPHook"], + "connection-types": [ + { + "hook-class-name": "airflow.providers.sftp.hooks.sftp.SFTPHook", + "connection-type": "wrong-connection-type", + } + ], + }, + ) + providers_manager._discover_hooks() + _ = providers_manager._hooks_lazy_dict["wrong-connection-type"] + assert len(self._caplog.entries) == 1 + assert "Inconsistency!" 
in self._caplog[0]["event"] + assert "sftp" not in providers_manager._hooks_lazy_dict + + def test_warning_logs_not_generated(self): + with self._caplog.at_level(logging.WARNING): + providers_manager = ProvidersManagerTaskRuntime() + providers_manager._provider_dict["apache-airflow-providers-sftp"] = ProviderInfo( + version="0.0.1", + data={ + "hook-class-names": ["airflow.providers.sftp.hooks.sftp.SFTPHook"], + "connection-types": [ + { + "hook-class-name": "airflow.providers.sftp.hooks.sftp.SFTPHook", + "connection-type": "sftp", + } + ], + }, + ) + providers_manager._discover_hooks() + _ = providers_manager._hooks_lazy_dict["sftp"] + assert not self._caplog.records + assert "sftp" in providers_manager.hooks + + def test_already_registered_conn_type_in_provide(self): + with self._caplog.at_level(logging.WARNING): + providers_manager = ProvidersManagerTaskRuntime() + providers_manager._provider_dict["apache-airflow-providers-dummy"] = ProviderInfo( + version="0.0.1", + data={ + "connection-types": [ + { + "hook-class-name": "airflow.providers.dummy.hooks.dummy.DummyHook", + "connection-type": "dummy", + }, + { + "hook-class-name": "airflow.providers.dummy.hooks.dummy.DummyHook2", + "connection-type": "dummy", + }, + ], + }, + ) + providers_manager._discover_hooks() + _ = providers_manager._hooks_lazy_dict["dummy"] + assert len(self._caplog.records) == 1 + msg = self._caplog.messages[0] + assert msg.startswith("The connection type 'dummy' is already registered") + assert ( + "different class names: 'airflow.providers.dummy.hooks.dummy.DummyHook'" + " and 'airflow.providers.dummy.hooks.dummy.DummyHook2'." + ) in msg + + def test_hooks(self): + with warnings.catch_warnings(record=True) as warning_records: + with self._caplog.at_level(logging.WARNING): + provider_manager = ProvidersManagerTaskRuntime() + connections_list = list(provider_manager.hooks.keys()) + assert len(connections_list) > 60 + if len(self._caplog.records) != 0: + for record in self._caplog.records: + print(record.message, file=sys.stderr) + print(record.exc_info, file=sys.stderr) + raise AssertionError("There are warnings generated during hook imports. 
Please fix them") + assert [w.message for w in warning_records if "hook-class-names" in str(w.message)] == [] + + @skip_if_not_on_main + @pytest.mark.execution_timeout(150) + def test_hook_values(self): + provider_dependencies = json.loads( + (AIRFLOW_ROOT_PATH / "generated" / "provider_dependencies.json").read_text() + ) + python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + excluded_providers: list[str] = [] + for provider_name, provider_info in provider_dependencies.items(): + if python_version in provider_info.get("excluded-python-versions", []): + excluded_providers.append(f"apache-airflow-providers-{provider_name.replace('.', '-')}") + with warnings.catch_warnings(record=True) as warning_records: + with self._caplog.at_level(logging.WARNING): + provider_manager = ProvidersManagerTaskRuntime() + connections_list = list(provider_manager.hooks.values()) + assert len(connections_list) > 60 + if len(self._caplog.records) != 0: + real_warning_count = 0 + for record in self._caplog.entries: + # When there is error importing provider that is excluded the provider name is in the message + if any(excluded_provider in record["event"] for excluded_provider in excluded_providers): + continue + print(record["event"], file=sys.stderr) + print(record.get("exc_info"), file=sys.stderr) + real_warning_count += 1 + if real_warning_count: + if PY313: + only_ydb_and_yandexcloud_warnings = True + for record in warning_records: + if "ydb" in str(record.message) or "yandexcloud" in str(record.message): + continue + only_ydb_and_yandexcloud_warnings = False + if only_ydb_and_yandexcloud_warnings: + print( + "Only warnings from ydb and yandexcloud providers are generated, " + "which is expected in Python 3.13+", + file=sys.stderr, + ) + return + raise AssertionError("There are warnings generated during hook imports. Please fix them") + assert [w.message for w in warning_records if "hook-class-names" in str(w.message)] == [] + + @patch("airflow.sdk.providers_manager_runtime.import_string") + def test_optional_feature_no_warning(self, mock_importlib_import_string): + with self._caplog.at_level(logging.WARNING): + mock_importlib_import_string.side_effect = AirflowOptionalProviderFeatureException() + providers_manager = ProvidersManagerTaskRuntime() + providers_manager._hook_provider_dict["test_connection"] = HookClassProvider( + package_name="test_package", hook_class_name="HookClass" + ) + providers_manager._import_hook( + hook_class_name=None, provider_info=None, package_name=None, connection_type="test_connection" + ) + assert self._caplog.messages == [] + + @patch("airflow.sdk.providers_manager_runtime.import_string") + def test_optional_feature_debug(self, mock_importlib_import_string): + with self._caplog.at_level(logging.INFO): + mock_importlib_import_string.side_effect = AirflowOptionalProviderFeatureException() + providers_manager = ProvidersManagerTaskRuntime() + providers_manager._hook_provider_dict["test_connection"] = HookClassProvider( + package_name="test_package", hook_class_name="HookClass" + ) + providers_manager._import_hook( + hook_class_name=None, provider_info=None, package_name=None, connection_type="test_connection" + ) + assert self._caplog.messages == [ + "Optional provider feature disabled when importing 'HookClass' from 'test_package' package" + ]