diff --git a/airflow-core/pyproject.toml b/airflow-core/pyproject.toml index 3f3cfd996e49e..f4e9f2022f571 100644 --- a/airflow-core/pyproject.toml +++ b/airflow-core/pyproject.toml @@ -231,6 +231,7 @@ exclude = [ "../shared/secrets_backend/src/airflow_shared/secrets_backend" = "src/airflow/_shared/secrets_backend" "../shared/secrets_masker/src/airflow_shared/secrets_masker" = "src/airflow/_shared/secrets_masker" "../shared/timezones/src/airflow_shared/timezones" = "src/airflow/_shared/timezones" +"../shared/plugins_manager/src/airflow_shared/plugins_manager" = "src/airflow/_shared/plugins_manager" [tool.hatch.build.targets.custom] path = "./hatch_build.py" @@ -306,4 +307,5 @@ shared_distributions = [ "apache-airflow-shared-secrets-backend", "apache-airflow-shared-secrets-masker", "apache-airflow-shared-timezones", + "apache-airflow-shared-plugins-manager", ] diff --git a/airflow-core/src/airflow/_shared/plugins_manager b/airflow-core/src/airflow/_shared/plugins_manager new file mode 120000 index 0000000000000..ed5b6d4795048 --- /dev/null +++ b/airflow-core/src/airflow/_shared/plugins_manager @@ -0,0 +1 @@ +../../../../shared/plugins_manager/src/airflow_shared/plugins_manager \ No newline at end of file diff --git a/airflow-core/src/airflow/plugins_manager.py b/airflow-core/src/airflow/plugins_manager.py index fd4b538f36fd8..fa2d4cf41bcb2 100644 --- a/airflow-core/src/airflow/plugins_manager.py +++ b/airflow-core/src/airflow/plugins_manager.py @@ -19,25 +19,25 @@ from __future__ import annotations -import importlib.machinery -import importlib.util import inspect import logging -import os -import sys -import types from collections.abc import Iterable from functools import cache -from pathlib import Path from typing import TYPE_CHECKING, Any from airflow import settings from airflow._shared.module_loading import ( - entry_points_with_dist, - find_path_from_directory, import_string, qualname, ) +from airflow._shared.plugins_manager import ( + AirflowPlugin, + AirflowPluginSource as AirflowPluginSource, + PluginsDirectorySource as PluginsDirectorySource, + _load_entrypoint_plugins, + _load_plugins_from_plugin_directory, + is_valid_plugin, +) from airflow.configuration import conf from airflow.task.priority_strategy import ( PriorityWeightStrategy, @@ -46,210 +46,12 @@ if TYPE_CHECKING: from airflow.lineage.hook import HookLineageReader - - if sys.version_info >= (3, 12): - from importlib import metadata - else: - import importlib_metadata as metadata - from collections.abc import Generator - from types import ModuleType - from airflow.listeners.listener import ListenerManager from airflow.timetables.base import Timetable log = logging.getLogger(__name__) -class AirflowPluginSource: - """Class used to define an AirflowPluginSource.""" - - def __str__(self): - raise NotImplementedError - - def __html__(self): - raise NotImplementedError - - -class PluginsDirectorySource(AirflowPluginSource): - """Class used to define Plugins loaded from Plugins Directory.""" - - def __init__(self, path): - self.path = os.path.relpath(path, settings.PLUGINS_FOLDER) - - def __str__(self): - return f"$PLUGINS_FOLDER/{self.path}" - - def __html__(self): - return f"$PLUGINS_FOLDER/{self.path}" - - -class EntryPointSource(AirflowPluginSource): - """Class used to define Plugins loaded from entrypoint.""" - - def __init__(self, entrypoint: metadata.EntryPoint, dist: metadata.Distribution): - self.dist = dist.metadata["Name"] # type: ignore[index] - self.version = dist.version - self.entrypoint = str(entrypoint) - - def 
__str__(self): - return f"{self.dist}=={self.version}: {self.entrypoint}" - - def __html__(self): - return f"{self.dist}=={self.version}: {self.entrypoint}" - - -class AirflowPluginException(Exception): - """Exception when loading plugin.""" - - -class AirflowPlugin: - """Class used to define AirflowPlugin.""" - - name: str | None = None - source: AirflowPluginSource | None = None - macros: list[Any] = [] - admin_views: list[Any] = [] - flask_blueprints: list[Any] = [] - fastapi_apps: list[Any] = [] - fastapi_root_middlewares: list[Any] = [] - external_views: list[Any] = [] - react_apps: list[Any] = [] - menu_links: list[Any] = [] - appbuilder_views: list[Any] = [] - appbuilder_menu_items: list[Any] = [] - - # A list of global operator extra links that can redirect users to - # external systems. These extra links will be available on the - # task page in the form of buttons. - # - # Note: the global operator extra link can be overridden at each - # operator level. - global_operator_extra_links: list[Any] = [] - - # A list of operator extra links to override or add operator links - # to existing Airflow Operators. - # These extra links will be available on the task page in form of - # buttons. - operator_extra_links: list[Any] = [] - - # A list of timetable classes that can be used for DAG scheduling. - timetables: list[type[Timetable]] = [] - - # A list of listeners that can be used for tracking task and DAG states. - listeners: list[ModuleType | object] = [] - - # A list of hook lineage reader classes that can be used for reading lineage information from a hook. - hook_lineage_readers: list[type[HookLineageReader]] = [] - - # A list of priority weight strategy classes that can be used for calculating tasks weight priority. - priority_weight_strategies: list[type[PriorityWeightStrategy]] = [] - - @classmethod - def validate(cls): - """Validate if plugin has a name.""" - if not cls.name: - raise AirflowPluginException("Your plugin needs a name.") - - @classmethod - def on_load(cls, *args, **kwargs): - """ - Execute when the plugin is loaded; This method is only called once during runtime. - - :param args: If future arguments are passed in on call. - :param kwargs: If future arguments are passed in on call. - """ - - -def is_valid_plugin(plugin_obj) -> bool: - """ - Check whether a potential object is a subclass of the AirflowPlugin class. - - :param plugin_obj: potential subclass of AirflowPlugin - :return: Whether or not the obj is a valid subclass of - AirflowPlugin - """ - if ( - inspect.isclass(plugin_obj) - and issubclass(plugin_obj, AirflowPlugin) - and (plugin_obj is not AirflowPlugin) - ): - plugin_obj.validate() - return True - return False - - -def _load_entrypoint_plugins() -> tuple[list[AirflowPlugin], dict[str, str]]: - """ - Load and register plugins AirflowPlugin subclasses from the entrypoints. - - The entry_point group should be 'airflow.plugins'. 
- """ - log.debug("Loading plugins from entrypoints") - - plugins: list[AirflowPlugin] = [] - import_errors: dict[str, str] = {} - for entry_point, dist in entry_points_with_dist("airflow.plugins"): - log.debug("Importing entry_point plugin %s", entry_point.name) - try: - plugin_class = entry_point.load() - if not is_valid_plugin(plugin_class): - continue - - plugin_instance: AirflowPlugin = plugin_class() - plugin_instance.source = EntryPointSource(entry_point, dist) - plugins.append(plugin_instance) - except Exception as e: - log.exception("Failed to import plugin %s", entry_point.name) - import_errors[entry_point.module] = str(e) - return plugins, import_errors - - -def _load_plugins_from_plugin_directory() -> tuple[list[AirflowPlugin], dict[str, str]]: - """Load and register Airflow Plugins from plugins directory.""" - if settings.PLUGINS_FOLDER is None: - raise ValueError("Plugins folder is not set") - log.debug("Loading plugins from directory: %s", settings.PLUGINS_FOLDER) - ignore_file_syntax = conf.get_mandatory_value("core", "DAG_IGNORE_FILE_SYNTAX", fallback="glob") - files = find_path_from_directory(settings.PLUGINS_FOLDER, ".airflowignore", ignore_file_syntax) - plugin_search_locations: list[tuple[str, Generator[str, None, None]]] = [("", files)] - - if conf.getboolean("core", "LOAD_EXAMPLES"): - log.debug("Note: Loading plugins from examples as well: %s", settings.PLUGINS_FOLDER) - from airflow.example_dags import plugins as example_plugins - - example_plugins_folder = next(iter(example_plugins.__path__)) - example_files = find_path_from_directory(example_plugins_folder, ".airflowignore", ignore_file_syntax) - plugin_search_locations.append((example_plugins.__name__, example_files)) - - plugins: list[AirflowPlugin] = [] - import_errors: dict[str, str] = {} - for module_prefix, plugin_files in plugin_search_locations: - for file_path in plugin_files: - path = Path(file_path) - if not path.is_file() or path.suffix != ".py": - continue - mod_name = f"{module_prefix}.{path.stem}" if module_prefix else path.stem - - try: - loader = importlib.machinery.SourceFileLoader(mod_name, file_path) - spec = importlib.util.spec_from_loader(mod_name, loader) - if not spec: - log.error("Could not load spec for module %s at %s", mod_name, file_path) - continue - mod = importlib.util.module_from_spec(spec) - sys.modules[spec.name] = mod - loader.exec_module(mod) - - for mod_attr_value in (m for m in mod.__dict__.values() if is_valid_plugin(m)): - plugin_instance: AirflowPlugin = mod_attr_value() - plugin_instance.source = PluginsDirectorySource(file_path) - plugins.append(plugin_instance) - except Exception as e: - log.exception("Failed to import plugin %s", file_path) - import_errors[file_path] = str(e) - return plugins, import_errors - - def _load_providers_plugins() -> tuple[list[AirflowPlugin], dict[str, str]]: from airflow.providers_manager import ProvidersManager @@ -273,19 +75,6 @@ def _load_providers_plugins() -> tuple[list[AirflowPlugin], dict[str, str]]: return plugins, import_errors -def make_module(name: str, objects: list[Any]) -> ModuleType | None: - """Create new module.""" - if not objects: - return None - log.debug("Creating module %s", name) - name = name.lower() - module = types.ModuleType(name) - module._name = name.split(".")[-1] # type: ignore - module._objects = objects # type: ignore - module.__dict__.update((o.__name__, o) for o in objects) - return module - - def ensure_plugins_loaded() -> None: """ Load plugins from plugins directory and entrypoints. 
@@ -329,7 +118,16 @@ def __register_plugins(plugin_instances: list[AirflowPlugin], errors: dict[str, import_errors.update(errors) with Stats.timer() as timer: - __register_plugins(*_load_plugins_from_plugin_directory()) + load_examples = conf.getboolean("core", "LOAD_EXAMPLES") + ignore_file_syntax = conf.get_mandatory_value("core", "DAG_IGNORE_FILE_SYNTAX", fallback="glob") + __register_plugins( + *_load_plugins_from_plugin_directory( + plugins_folder=settings.PLUGINS_FOLDER, + load_examples=load_examples, + example_plugins_module="airflow.example_dags.plugins" if load_examples else None, + ignore_file_syntax=ignore_file_syntax, + ) + ) __register_plugins(*_load_entrypoint_plugins()) if not settings.LAZY_LOAD_PROVIDERS: @@ -492,31 +290,27 @@ def get_hook_lineage_readers_plugins() -> list[type[HookLineageReader]]: @cache def integrate_macros_plugins() -> None: """Integrates macro plugins.""" + from airflow._shared.plugins_manager import ( + integrate_macros_plugins as _integrate_macros_plugins, + ) from airflow.sdk.execution_time import macros - log.debug("Integrate Macros plugins") - - for plugin in _get_plugins()[0]: - if plugin.name is None: - raise AirflowPluginException("Invalid plugin name") - - macros_module = make_module(f"airflow.sdk.execution_time.macros.{plugin.name}", plugin.macros) - - if macros_module: - sys.modules[macros_module.__name__] = macros_module - # Register the newly created module on airflow.macros such that it - # can be accessed when rendering templates. - setattr(macros, plugin.name, macros_module) + plugins, _ = _get_plugins() + _integrate_macros_plugins( + target_macros_module=macros, + macros_module_name_prefix="airflow.sdk.execution_time.macros", + plugins=plugins, + ) def integrate_listener_plugins(listener_manager: ListenerManager) -> None: """Add listeners from plugins.""" - for plugin in _get_plugins()[0]: - if plugin.name is None: - raise AirflowPluginException("Invalid plugin name") + from airflow._shared.plugins_manager import ( + integrate_listener_plugins as _integrate_listener_plugins, + ) - for listener in plugin.listeners: - listener_manager.add_listener(listener) + plugins, _ = _get_plugins() + _integrate_listener_plugins(listener_manager, plugins=plugins) def get_plugin_info(attrs_to_dump: Iterable[str] | None = None) -> list[dict[str, Any]]: diff --git a/airflow-core/tests/unit/plugins/test_plugins_manager.py b/airflow-core/tests/unit/plugins/test_plugins_manager.py index 2b872f23c0bd3..182f912057af6 100644 --- a/airflow-core/tests/unit/plugins/test_plugins_manager.py +++ b/airflow-core/tests/unit/plugins/test_plugins_manager.py @@ -28,6 +28,7 @@ import pytest from airflow._shared.module_loading import qualname +from airflow.configuration import conf from airflow.listeners.listener import get_listener_manager from airflow.plugins_manager import AirflowPlugin @@ -88,7 +89,11 @@ def test_no_log_when_no_plugins(self, caplog): def test_loads_filesystem_plugins(self, caplog): from airflow import plugins_manager - plugins, import_errors = plugins_manager._load_plugins_from_plugin_directory() + plugins, import_errors = plugins_manager._load_plugins_from_plugin_directory( + plugins_folder=conf.get("core", "plugins_folder"), + load_examples=conf.getboolean("core", "load_examples"), + example_plugins_module="airflow.example_dags.plugins", + ) assert len(plugins) == 10 assert not import_errors @@ -266,38 +271,6 @@ class AirflowAdminMenuLinksPlugin(AirflowPlugin): assert caplog.record_tuples == [] - def 
test_entrypoint_plugin_errors_dont_raise_exceptions(self, mock_metadata_distribution, caplog): - """ - Test that Airflow does not raise an error if there is any Exception because of a plugin. - """ - from airflow.plugins_manager import _load_entrypoint_plugins - - mock_dist = mock.Mock() - mock_dist.metadata = {"Name": "test-dist"} - - mock_entrypoint = mock.Mock() - mock_entrypoint.name = "test-entrypoint" - mock_entrypoint.group = "airflow.plugins" - mock_entrypoint.module = "test.plugins.test_plugins_manager" - mock_entrypoint.load.side_effect = ImportError("my_fake_module not found") - mock_dist.entry_points = [mock_entrypoint] - - with ( - mock_metadata_distribution(return_value=[mock_dist]), - caplog.at_level(logging.ERROR, logger="airflow.plugins_manager"), - ): - _, import_errors = _load_entrypoint_plugins() - - received_logs = caplog.text - # Assert Traceback is shown too - assert "Traceback (most recent call last):" in received_logs - assert "my_fake_module not found" in received_logs - assert "Failed to import plugin test-entrypoint" in received_logs - assert ( - "test.plugins.test_plugins_manager", - "my_fake_module not found", - ) in import_errors.items() - def test_registering_plugin_macros(self, request): """ Tests whether macros that originate from plugins are being registered correctly. @@ -343,7 +316,13 @@ def test_registering_plugin_listeners(self): from airflow import plugins_manager assert not get_listener_manager().has_listeners - with mock_plugin_manager(plugins=plugins_manager._load_plugins_from_plugin_directory()[0]): + with mock_plugin_manager( + plugins=plugins_manager._load_plugins_from_plugin_directory( + plugins_folder=conf.get("core", "plugins_folder"), + load_examples=conf.getboolean("core", "load_examples"), + example_plugins_module="airflow.example_dags.plugins", + )[0] + ): plugins_manager.integrate_listener_plugins(get_listener_manager()) assert get_listener_manager().has_listeners @@ -381,35 +360,3 @@ def test_does_not_double_import_entrypoint_provider_plugins(self): with mock.patch("airflow.plugins_manager._load_plugins_from_plugin_directory", return_value=([], [])): plugins = plugins_manager._get_plugins()[0] assert len(plugins) == 4 - - -class TestPluginsDirectorySource: - def test_should_return_correct_path_name(self): - from airflow import plugins_manager - - source = plugins_manager.PluginsDirectorySource(__file__) - assert source.path == "test_plugins_manager.py" - assert str(source) == "$PLUGINS_FOLDER/test_plugins_manager.py" - assert source.__html__() == "$PLUGINS_FOLDER/test_plugins_manager.py" - - -class TestEntryPointSource: - def test_should_return_correct_source_details(self, mock_metadata_distribution): - from airflow import plugins_manager - - mock_entrypoint = mock.Mock() - mock_entrypoint.name = "test-entrypoint-plugin" - mock_entrypoint.module = "module_name_plugin" - - mock_dist = mock.Mock() - mock_dist.metadata = {"Name": "test-entrypoint-plugin"} - mock_dist.version = "1.0.0" - mock_dist.entry_points = [mock_entrypoint] - - with mock_metadata_distribution(return_value=[mock_dist]): - plugins_manager._load_entrypoint_plugins() - - source = plugins_manager.EntryPointSource(mock_entrypoint, mock_dist) - assert str(mock_entrypoint) == source.entrypoint - assert "test-entrypoint-plugin==1.0.0: " + str(mock_entrypoint) == str(source) - assert "test-entrypoint-plugin==1.0.0: " + str(mock_entrypoint) == source.__html__() diff --git a/providers/apache/hive/src/airflow/providers/apache/hive/plugins/hive.py 
b/providers/apache/hive/src/airflow/providers/apache/hive/plugins/hive.py index 63a068be291a9..d000779e95e9d 100644 --- a/providers/apache/hive/src/airflow/providers/apache/hive/plugins/hive.py +++ b/providers/apache/hive/src/airflow/providers/apache/hive/plugins/hive.py @@ -17,8 +17,8 @@ from __future__ import annotations -from airflow.plugins_manager import AirflowPlugin from airflow.providers.apache.hive.macros.hive import closest_ds_partition, max_partition +from airflow.providers.common.compat.sdk import AirflowPlugin class HivePlugin(AirflowPlugin): diff --git a/providers/common/compat/src/airflow/providers/common/compat/sdk.py b/providers/common/compat/src/airflow/providers/common/compat/sdk.py index 9811d0c019a51..67670f8588669 100644 --- a/providers/common/compat/src/airflow/providers/common/compat/sdk.py +++ b/providers/common/compat/src/airflow/providers/common/compat/sdk.py @@ -94,6 +94,7 @@ ) from airflow.sdk.log import redact as redact from airflow.sdk.observability.stats import Stats as Stats + from airflow.sdk.plugins_manager import AirflowPlugin as AirflowPlugin # Airflow 3-only exceptions (conditionally imported) if AIRFLOW_V_3_0_PLUS: @@ -174,6 +175,10 @@ # ============================================================================ "BaseNotifier": ("airflow.sdk", "airflow.notifications.basenotifier"), # ============================================================================ + # Plugins + # ============================================================================ + "AirflowPlugin": ("airflow.sdk.plugins_manager", "airflow.plugins_manager"), + # ============================================================================ # Operator Links & Task Groups # ============================================================================ "BaseOperatorLink": ("airflow.sdk", "airflow.models.baseoperatorlink"), diff --git a/providers/databricks/pyproject.toml b/providers/databricks/pyproject.toml index 64f8485528712..2f76d840ce2b6 100644 --- a/providers/databricks/pyproject.toml +++ b/providers/databricks/pyproject.toml @@ -59,7 +59,7 @@ requires-python = ">=3.10" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=2.11.0", - "apache-airflow-providers-common-compat>=1.10.1", # use next version + "apache-airflow-providers-common-compat>=1.10.1", # use next version "apache-airflow-providers-common-sql>=1.27.0", "requests>=2.32.0,<3", "databricks-sql-connector>=4.0.0", diff --git a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py index 905e79ec59bd1..dd595c49b7f7d 100644 --- a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py +++ b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py @@ -25,8 +25,13 @@ from airflow.exceptions import TaskInstanceNotFound from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance, TaskInstanceKey, clear_task_instances -from airflow.plugins_manager import AirflowPlugin -from airflow.providers.common.compat.sdk import AirflowException, BaseOperatorLink, TaskGroup, XCom +from airflow.providers.common.compat.sdk import ( + AirflowException, + AirflowPlugin, + BaseOperatorLink, + TaskGroup, + XCom, +) from airflow.providers.databricks.hooks.databricks import DatabricksHook from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS 
from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py index 2058ae41a43c6..28b917a0d588a 100644 --- a/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py +++ b/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py @@ -36,8 +36,7 @@ from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstanceKey -from airflow.plugins_manager import AirflowPlugin -from airflow.providers.common.compat.sdk import AirflowException +from airflow.providers.common.compat.sdk import AirflowException, AirflowPlugin from airflow.providers.databricks.plugins.databricks_workflow import ( DatabricksWorkflowPlugin, RepairDatabricksTasks, diff --git a/providers/edge3/pyproject.toml b/providers/edge3/pyproject.toml index 714f1e139985c..58a277e73dc6b 100644 --- a/providers/edge3/pyproject.toml +++ b/providers/edge3/pyproject.toml @@ -59,7 +59,7 @@ requires-python = ">=3.10" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=3.0.0,!=3.1.0", - "apache-airflow-providers-common-compat>=1.10.1", + "apache-airflow-providers-common-compat>=1.10.1", # use next version "pydantic>=2.11.0", "retryhttp>=1.2.0,!=1.3.0", ] diff --git a/providers/edge3/src/airflow/providers/edge3/plugins/edge_executor_plugin.py b/providers/edge3/src/airflow/providers/edge3/plugins/edge_executor_plugin.py index b8aa49a642ff4..ad22956258c0e 100644 --- a/providers/edge3/src/airflow/providers/edge3/plugins/edge_executor_plugin.py +++ b/providers/edge3/src/airflow/providers/edge3/plugins/edge_executor_plugin.py @@ -22,7 +22,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowConfigException -from airflow.plugins_manager import AirflowPlugin +from airflow.providers.common.compat.sdk import AirflowPlugin from airflow.providers.edge3.version_compat import AIRFLOW_V_3_1_PLUS from airflow.utils.session import NEW_SESSION, provide_session diff --git a/providers/edge3/tests/unit/edge3/plugins/test_edge_executor_plugin.py b/providers/edge3/tests/unit/edge3/plugins/test_edge_executor_plugin.py index 0eea2ae9c0d12..536270480a909 100644 --- a/providers/edge3/tests/unit/edge3/plugins/test_edge_executor_plugin.py +++ b/providers/edge3/tests/unit/edge3/plugins/test_edge_executor_plugin.py @@ -21,7 +21,7 @@ import pytest -from airflow.plugins_manager import AirflowPlugin +from airflow.providers.common.compat.sdk import AirflowPlugin from airflow.providers.edge3.plugins import edge_executor_plugin from tests_common.test_utils.config import conf_vars diff --git a/providers/fab/pyproject.toml b/providers/fab/pyproject.toml index 3944d207069ab..1440cf276d01b 100644 --- a/providers/fab/pyproject.toml +++ b/providers/fab/pyproject.toml @@ -58,7 +58,7 @@ requires-python = ">=3.10,!=3.13" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=3.0.2", - "apache-airflow-providers-common-compat>=1.10.1", + "apache-airflow-providers-common-compat>=1.10.1", # use next version # Blinker use for signals in Flask, this is an optional dependency in Flask 2.2 and lower. # In Flask 2.3 it becomes a mandatory dependency, and flask signals are always available. 
"blinker>=1.6.2; python_version < '3.13'", diff --git a/providers/fab/tests/unit/fab/plugins/test_plugin.py b/providers/fab/tests/unit/fab/plugins/test_plugin.py index 20ac29f96ddc2..3ce031ee893d7 100644 --- a/providers/fab/tests/unit/fab/plugins/test_plugin.py +++ b/providers/fab/tests/unit/fab/plugins/test_plugin.py @@ -24,7 +24,7 @@ from starlette.middleware.base import BaseHTTPMiddleware # This is the class you derive to create a plugin -from airflow.plugins_manager import AirflowPlugin +from airflow.providers.common.compat.sdk import AirflowPlugin from airflow.task.priority_strategy import PriorityWeightStrategy from airflow.timetables.interval import CronDataIntervalTimetable diff --git a/providers/openlineage/src/airflow/providers/openlineage/plugins/openlineage.py b/providers/openlineage/src/airflow/providers/openlineage/plugins/openlineage.py index 8e422d84ac604..374d8b2f06b3d 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/plugins/openlineage.py +++ b/providers/openlineage/src/airflow/providers/openlineage/plugins/openlineage.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from airflow.plugins_manager import AirflowPlugin +from airflow.providers.common.compat.sdk import AirflowPlugin from airflow.providers.openlineage import conf # Conditional imports - only load expensive dependencies when plugin is enabled diff --git a/pyproject.toml b/pyproject.toml index 1102d57e3b73f..41331960389d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1356,6 +1356,7 @@ apache-airflow-shared-secrets-backend = { workspace = true } apache-airflow-shared-secrets-masker = { workspace = true } apache-airflow-shared-timezones = { workspace = true } apache-airflow-shared-observability = { workspace = true } +apache-airflow-shared-plugins-manager = { workspace = true } # Automatically generated provider workspace items (update_airflow_pyproject_toml.py) apache-airflow-providers-airbyte = { workspace = true } apache-airflow-providers-alibaba = { workspace = true } @@ -1480,6 +1481,7 @@ members = [ "shared/secrets_masker", "shared/timezones", "shared/observability", + "shared/plugins_manager", # Automatically generated provider workspace members (update_airflow_pyproject_toml.py) "providers/airbyte", "providers/alibaba", diff --git a/shared/plugins_manager/pyproject.toml b/shared/plugins_manager/pyproject.toml new file mode 100644 index 0000000000000..059a552da29e7 --- /dev/null +++ b/shared/plugins_manager/pyproject.toml @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[project] +name = "apache-airflow-shared-plugins-manager" +description = "Shared plugins manager code for Airflow distributions" +version = "0.0" +classifiers = [ + "Private :: Do Not Upload", +] + +dependencies = [ + "pendulum>=3.1.0", + "methodtools>=0.4.7", + 'importlib_metadata>=6.5;python_version<"3.12"', + 'importlib_metadata>=7.0;python_version>="3.12"', +] + +[dependency-groups] +dev = [ + "apache-airflow-devel-common", + "apache-airflow-shared-module-loading", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/airflow_shared"] + +[tool.ruff] +extend = "../../pyproject.toml" +src = ["src"] + +[tool.ruff.lint.per-file-ignores] +# Ignore Doc rules et al for anything outside of tests +"!src/*" = ["D", "S101", "TRY002"] + +[tool.ruff.lint.flake8-tidy-imports] +# Override the workspace level default +ban-relative-imports = "parents" diff --git a/shared/plugins_manager/src/airflow_shared/plugins_manager/__init__.py b/shared/plugins_manager/src/airflow_shared/plugins_manager/__init__.py new file mode 100644 index 0000000000000..4fa713c2b4eb4 --- /dev/null +++ b/shared/plugins_manager/src/airflow_shared/plugins_manager/__init__.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from .plugins_manager import ( + AirflowPlugin as AirflowPlugin, + AirflowPluginException as AirflowPluginException, + AirflowPluginSource as AirflowPluginSource, + EntryPointSource as EntryPointSource, + PluginsDirectorySource as PluginsDirectorySource, + _load_entrypoint_plugins as _load_entrypoint_plugins, + _load_plugins_from_plugin_directory as _load_plugins_from_plugin_directory, + integrate_listener_plugins as integrate_listener_plugins, + integrate_macros_plugins as integrate_macros_plugins, + is_valid_plugin as is_valid_plugin, + make_module as make_module, +) diff --git a/shared/plugins_manager/src/airflow_shared/plugins_manager/plugins_manager.py b/shared/plugins_manager/src/airflow_shared/plugins_manager/plugins_manager.py new file mode 100644 index 0000000000000..86eb323067597 --- /dev/null +++ b/shared/plugins_manager/src/airflow_shared/plugins_manager/plugins_manager.py @@ -0,0 +1,303 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Manages all plugins.""" + +from __future__ import annotations + +import inspect +import logging +import os +import sys +import types +from pathlib import Path +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + if sys.version_info >= (3, 12): + from importlib import metadata + else: + import importlib_metadata as metadata + from collections.abc import Generator + from types import ModuleType + + from airflow.listeners.listener import ListenerManager + +log = logging.getLogger(__name__) + + +class AirflowPluginSource: + """Class used to define an AirflowPluginSource.""" + + def __str__(self): + raise NotImplementedError + + def __html__(self): + raise NotImplementedError + + +class PluginsDirectorySource(AirflowPluginSource): + """Class used to define Plugins loaded from Plugins Directory.""" + + def __init__(self, path, plugins_folder: str): + self.path = os.path.relpath(path, plugins_folder) + + def __str__(self): + return f"$PLUGINS_FOLDER/{self.path}" + + def __html__(self): + return f"$PLUGINS_FOLDER/{self.path}" + + +class EntryPointSource(AirflowPluginSource): + """Class used to define Plugins loaded from entrypoint.""" + + def __init__(self, entrypoint: metadata.EntryPoint, dist: metadata.Distribution): + self.dist = dist.metadata["Name"] # type: ignore[index] + self.version = dist.version + self.entrypoint = str(entrypoint) + + def __str__(self): + return f"{self.dist}=={self.version}: {self.entrypoint}" + + def __html__(self): + return f"{self.dist}=={self.version}: {self.entrypoint}" + + +class AirflowPluginException(Exception): + """Exception when loading plugin.""" + + +class AirflowPlugin: + """Class used to define AirflowPlugin.""" + + name: str | None = None + source: AirflowPluginSource | None = None + macros: list[Any] = [] + admin_views: list[Any] = [] + flask_blueprints: list[Any] = [] + fastapi_apps: list[Any] = [] + fastapi_root_middlewares: list[Any] = [] + external_views: list[Any] = [] + react_apps: list[Any] = [] + menu_links: list[Any] = [] + appbuilder_views: list[Any] = [] + appbuilder_menu_items: list[Any] = [] + + # A list of global operator extra links that can redirect users to + # external systems. These extra links will be available on the + # task page in the form of buttons. + # + # Note: the global operator extra link can be overridden at each + # operator level. + global_operator_extra_links: list[Any] = [] + + # A list of operator extra links to override or add operator links + # to existing Airflow Operators. + # These extra links will be available on the task page in form of + # buttons. + operator_extra_links: list[Any] = [] + + # A list of timetable classes that can be used for DAG scheduling. + timetables: list[Any] = [] + + # A list of listeners that can be used for tracking task and DAG states. + listeners: list[ModuleType | object] = [] + + # A list of hook lineage reader classes that can be used for reading lineage information from a hook. + hook_lineage_readers: list[Any] = [] + + # A list of priority weight strategy classes that can be used for calculating tasks weight priority. 
+    priority_weight_strategies: list[Any] = []
+
+    @classmethod
+    def validate(cls):
+        """Validate that the plugin has a name."""
+        if not cls.name:
+            raise AirflowPluginException("Your plugin needs a name.")
+
+    @classmethod
+    def on_load(cls, *args, **kwargs):
+        """
+        Execute when the plugin is loaded; this method is only called once during runtime.
+
+        :param args: If future arguments are passed in on call.
+        :param kwargs: If future arguments are passed in on call.
+        """
+
+
+def is_valid_plugin(plugin_obj) -> bool:
+    """
+    Check whether a potential object is a subclass of the AirflowPlugin class.
+
+    :param plugin_obj: potential subclass of AirflowPlugin
+    :return: Whether the object is a valid subclass of AirflowPlugin
+    """
+    if not inspect.isclass(plugin_obj):
+        return False
+
+    # Temporarily, we use name-based checking instead of issubclass() because the shared library
+    # is accessed via different symlink paths in core (airflow._shared) and the Task SDK (airflow.sdk._shared).
+    # Python treats these as different modules, so the AirflowPlugin class has a different identity in each context.
+    # Providers will typically inherit from the SDK AirflowPlugin, so using issubclass() would fail when core
+    # tries to validate those plugins, and provider plugins would not work in Airflow core.
+    # For now, validating by class name allows plugins defined against either
+    # core's or the SDK's AirflowPlugin to be loaded.
+    is_airflow_plugin = any(
+        base.__name__ == "AirflowPlugin" and "plugins_manager" in base.__module__
+        for base in plugin_obj.__mro__
+    )
+
+    if is_airflow_plugin and plugin_obj.__name__ != "AirflowPlugin":
+        plugin_obj.validate()
+        return True
+    return False
+
+
+def _load_entrypoint_plugins() -> tuple[list[AirflowPlugin], dict[str, str]]:
+    """
+    Load and register AirflowPlugin subclasses from entrypoints.
+
+    The entry_point group should be 'airflow.plugins'.
+    """
+    from ..module_loading import entry_points_with_dist
+
+    log.debug("Loading plugins from entrypoints")
+
+    plugins: list[AirflowPlugin] = []
+    import_errors: dict[str, str] = {}
+    for entry_point, dist in entry_points_with_dist("airflow.plugins"):
+        log.debug("Importing entry_point plugin %s", entry_point.name)
+        try:
+            plugin_class = entry_point.load()
+            if not is_valid_plugin(plugin_class):
+                continue
+
+            plugin_instance: AirflowPlugin = plugin_class()
+            plugin_instance.source = EntryPointSource(entry_point, dist)
+            plugins.append(plugin_instance)
+        except Exception as e:
+            log.exception("Failed to import plugin %s", entry_point.name)
+            import_errors[entry_point.module] = str(e)
+    return plugins, import_errors
+
+
+def _load_plugins_from_plugin_directory(
+    plugins_folder: str,
+    load_examples: bool = False,
+    example_plugins_module: str | None = None,
+    ignore_file_syntax: str = "glob",
+) -> tuple[list[AirflowPlugin], dict[str, str]]:
+    """Load and register Airflow plugins from the plugins directory."""
+    # importlib.machinery and importlib.util must be imported explicitly;
+    # "import importlib" alone does not make these submodules available.
+    import importlib.machinery
+    import importlib.util
+
+    from ..module_loading import find_path_from_directory
+
+    if not plugins_folder:
+        raise ValueError("Plugins folder is not set")
+    log.debug("Loading plugins from directory: %s", plugins_folder)
+    files = find_path_from_directory(plugins_folder, ".airflowignore", ignore_file_syntax)
+    plugin_search_locations: list[tuple[str, Generator[str, None, None]]] = [("", files)]
+
+    if load_examples and example_plugins_module:
+        log.debug("Note: Loading plugins from examples as well: %s", plugins_folder)
+        example_plugins = importlib.import_module(example_plugins_module)
+        example_plugins_folder = next(iter(example_plugins.__path__))
+        example_files = find_path_from_directory(example_plugins_folder, ".airflowignore", ignore_file_syntax)
+        plugin_search_locations.append((example_plugins.__name__, example_files))
+
+    plugins: list[AirflowPlugin] = []
+    import_errors: dict[str, str] = {}
+    for module_prefix, plugin_files in plugin_search_locations:
+        for file_path in plugin_files:
+            path = Path(file_path)
+            if not path.is_file() or path.suffix != ".py":
+                continue
+            mod_name = f"{module_prefix}.{path.stem}" if module_prefix else path.stem
+
+            try:
+                loader = importlib.machinery.SourceFileLoader(mod_name, file_path)
+                spec = importlib.util.spec_from_loader(mod_name, loader)
+                if not spec:
+                    log.error("Could not load spec for module %s at %s", mod_name, file_path)
+                    continue
+                mod = importlib.util.module_from_spec(spec)
+                sys.modules[spec.name] = mod
+                loader.exec_module(mod)
+
+                for mod_attr_value in (m for m in mod.__dict__.values() if is_valid_plugin(m)):
+                    plugin_instance: AirflowPlugin = mod_attr_value()
+                    plugin_instance.source = PluginsDirectorySource(file_path, plugins_folder)
+                    plugins.append(plugin_instance)
+            except Exception as e:
+                log.exception("Failed to import plugin %s", file_path)
+                import_errors[file_path] = str(e)
+    return plugins, import_errors
+
+
+def make_module(name: str, objects: list[Any]) -> ModuleType | None:
+    """Create new module."""
+    if not objects:
+        return None
+    log.debug("Creating module %s", name)
+    name = name.lower()
+    module = types.ModuleType(name)
+    module._name = name.split(".")[-1]  # type: ignore
+    module._objects = objects  # type: ignore
+    module.__dict__.update((o.__name__, o) for o in objects)
+    return module
+
+
+def integrate_macros_plugins(
+    target_macros_module: ModuleType, macros_module_name_prefix: str, plugins: list[AirflowPlugin]
+) -> None:
+    """
+    Register macros from plugins onto the target macros module.
+ + For each plugin with macros, creates a submodule and attaches it to + the target module so macros can be accessed in templates as + ``{{ macros.plugin_name.macro_func() }}``. + """ + log.debug("Integrate Macros plugins") + + for plugin in plugins: + if plugin.name is None: + raise AirflowPluginException("Invalid plugin name") + + macros_module_instance = make_module(f"{macros_module_name_prefix}.{plugin.name}", plugin.macros) + + if macros_module_instance: + sys.modules[macros_module_instance.__name__] = macros_module_instance + # Register the newly created module on the provided macros module + # so it can be accessed when rendering templates. + setattr(target_macros_module, plugin.name, macros_module_instance) + + +def integrate_listener_plugins(listener_manager: ListenerManager, plugins: list[AirflowPlugin]) -> None: + """ + Register listeners from plugins with the listener manager. + + For each plugin with listeners, registers them with the provided + ListenerManager. + """ + for plugin in plugins: + if plugin.name is None: + raise AirflowPluginException("Invalid plugin name") + + for listener in plugin.listeners: + listener_manager.add_listener(listener) diff --git a/shared/plugins_manager/tests/conftest.py b/shared/plugins_manager/tests/conftest.py new file mode 100644 index 0000000000000..93aecf261843a --- /dev/null +++ b/shared/plugins_manager/tests/conftest.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os + +os.environ["_AIRFLOW__AS_LIBRARY"] = "true" diff --git a/shared/plugins_manager/tests/plugins_manager/__init__.py b/shared/plugins_manager/tests/plugins_manager/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/shared/plugins_manager/tests/plugins_manager/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
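
The name-based check in is_valid_plugin() above is the subtle part of this refactor:
the shared module is vendored under two different import paths
(airflow._shared.plugins_manager in core, airflow.sdk._shared.plugins_manager in the
Task SDK), so issubclass() against one vendored copy's AirflowPlugin fails for
subclasses of the other. A minimal sketch of the behavior the MRO walk relies on --
MyPlugin here is a hypothetical plugin, not part of this change:

    # Hypothetical user plugin inheriting from the SDK-side vendored copy.
    from airflow.sdk._shared.plugins_manager import AirflowPlugin

    class MyPlugin(AirflowPlugin):
        name = "my_plugin"

    # is_valid_plugin() matches by class name and module rather than by identity,
    # so the check succeeds regardless of which vendored copy provided the base:
    assert any(
        base.__name__ == "AirflowPlugin" and "plugins_manager" in base.__module__
        for base in MyPlugin.__mro__
    )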
diff --git a/shared/plugins_manager/tests/plugins_manager/test_plugins_manager.py b/shared/plugins_manager/tests/plugins_manager/test_plugins_manager.py new file mode 100644 index 0000000000000..723c0ae4b54ec --- /dev/null +++ b/shared/plugins_manager/tests/plugins_manager/test_plugins_manager.py @@ -0,0 +1,111 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import contextlib +import logging +import sys +from unittest import mock + +import pytest + +from airflow_shared.plugins_manager import ( + EntryPointSource, + PluginsDirectorySource, + _load_entrypoint_plugins, +) + + +@pytest.fixture +def mock_metadata_distribution(mocker): + @contextlib.contextmanager + def wrapper(*args, **kwargs): + if sys.version_info < (3, 12): + patch_fq = "importlib_metadata.distributions" + else: + patch_fq = "importlib.metadata.distributions" + + with mock.patch(patch_fq, *args, **kwargs) as m: + yield m + + return wrapper + + +class TestPluginsDirectorySource: + def test_should_return_correct_path_name(self, tmp_path): + plugins_folder = str(tmp_path) + test_file = tmp_path / "test_plugins_manager.py" + test_file.write_text("# test file") + + source = PluginsDirectorySource(str(test_file), plugins_folder) + assert source.path == "test_plugins_manager.py" + assert str(source) == "$PLUGINS_FOLDER/test_plugins_manager.py" + assert source.__html__() == "$PLUGINS_FOLDER/test_plugins_manager.py" + + +class TestEntryPointSource: + def test_should_return_correct_source_details(self, mock_metadata_distribution): + mock_entrypoint = mock.Mock() + mock_entrypoint.name = "test-entrypoint-plugin" + mock_entrypoint.module = "module_name_plugin" + mock_entrypoint.group = "airflow.plugins" + + mock_dist = mock.Mock() + mock_dist.metadata = {"Name": "test-entrypoint-plugin"} + mock_dist.version = "1.0.0" + mock_dist.entry_points = [mock_entrypoint] + + with mock_metadata_distribution(return_value=[mock_dist]): + _load_entrypoint_plugins() + + source = EntryPointSource(mock_entrypoint, mock_dist) + assert str(mock_entrypoint) == source.entrypoint + assert "test-entrypoint-plugin==1.0.0: " + str(mock_entrypoint) == str(source) + assert "test-entrypoint-plugin==1.0.0: " + str(mock_entrypoint) == source.__html__() + + +class TestPluginsManager: + def test_entrypoint_plugin_errors_dont_raise_exceptions(self, mock_metadata_distribution, caplog): + """ + Test that Airflow does not raise an error if there is any Exception because of a plugin. 
+ """ + mock_dist = mock.Mock() + mock_dist.metadata = {"Name": "test-dist"} + + mock_entrypoint = mock.Mock() + mock_entrypoint.name = "test-entrypoint" + mock_entrypoint.group = "airflow.plugins" + mock_entrypoint.module = "test.plugins.test_plugins_manager" + mock_entrypoint.load.side_effect = ImportError("my_fake_module not found") + mock_dist.entry_points = [mock_entrypoint] + + with ( + mock_metadata_distribution(return_value=[mock_dist]), + caplog.at_level(logging.ERROR, logger="airflow_shared.plugins_manager.plugins_manager"), + ): + _, import_errors = _load_entrypoint_plugins() + + received_logs = caplog.text + # Assert Traceback is shown too + assert "Traceback (most recent call last):" in received_logs + assert "my_fake_module not found" in received_logs + assert "Failed to import plugin test-entrypoint" in received_logs + assert ( + "test.plugins.test_plugins_manager", + "my_fake_module not found", + ) in import_errors.items() diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml index 3479c0c04c4b6..e1ece8f283110 100644 --- a/task-sdk/pyproject.toml +++ b/task-sdk/pyproject.toml @@ -127,6 +127,7 @@ path = "src/airflow/sdk/__init__.py" "../shared/secrets_backend/src/airflow_shared/secrets_backend" = "src/airflow/sdk/_shared/secrets_backend" "../shared/secrets_masker/src/airflow_shared/secrets_masker" = "src/airflow/sdk/_shared/secrets_masker" "../shared/timezones/src/airflow_shared/timezones" = "src/airflow/sdk/_shared/timezones" +"../shared/plugins_manager/src/airflow_shared/plugins_manager" = "src/airflow/sdk/_shared/plugins_manager" [tool.hatch.build.targets.wheel] packages = ["src/airflow"] @@ -276,4 +277,5 @@ shared_distributions = [ "apache-airflow-shared-secrets-masker", "apache-airflow-shared-timezones", "apache-airflow-shared-observability", + "apache-airflow-shared-plugins-manager", ] diff --git a/task-sdk/src/airflow/sdk/_shared/plugins_manager b/task-sdk/src/airflow/sdk/_shared/plugins_manager new file mode 120000 index 0000000000000..366cca55e6ee5 --- /dev/null +++ b/task-sdk/src/airflow/sdk/_shared/plugins_manager @@ -0,0 +1 @@ +../../../../../shared/plugins_manager/src/airflow_shared/plugins_manager \ No newline at end of file diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index e910d8de6d770..2a6bf6f7994e4 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -183,7 +183,7 @@ def __rich_repr__(self): def get_template_context(self) -> Context: # TODO: Move this to `airflow.sdk.execution_time.context` # once we port the entire context logic from airflow/utils/context.py ? - from airflow.plugins_manager import integrate_macros_plugins + from airflow.sdk.plugins_manager import integrate_macros_plugins integrate_macros_plugins() diff --git a/task-sdk/src/airflow/sdk/plugins_manager.py b/task-sdk/src/airflow/sdk/plugins_manager.py new file mode 100644 index 0000000000000..17672b5ae70ff --- /dev/null +++ b/task-sdk/src/airflow/sdk/plugins_manager.py @@ -0,0 +1,133 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""SDK wrapper for plugins manager.""" + +from __future__ import annotations + +import logging +from functools import cache +from typing import TYPE_CHECKING + +from airflow import settings +from airflow.observability.stats import Stats +from airflow.providers_manager import ProvidersManager +from airflow.sdk._shared.module_loading import import_string +from airflow.sdk._shared.plugins_manager import ( + AirflowPlugin, + _load_entrypoint_plugins, + _load_plugins_from_plugin_directory, + integrate_listener_plugins as _integrate_listener_plugins, + integrate_macros_plugins as _integrate_macros_plugins, + is_valid_plugin, +) +from airflow.sdk.configuration import conf + +if TYPE_CHECKING: + from airflow.listeners.listener import ListenerManager + +log = logging.getLogger(__name__) + + +def _load_providers_plugins() -> tuple[list[AirflowPlugin], dict[str, str]]: + """Load plugins from providers.""" + log.debug("Loading plugins from providers") + providers_manager = ProvidersManager() + providers_manager.initialize_providers_plugins() + + plugins: list[AirflowPlugin] = [] + import_errors: dict[str, str] = {} + for plugin in providers_manager.plugins: + log.debug("Importing plugin %s from class %s", plugin.name, plugin.plugin_class) + + try: + plugin_instance = import_string(plugin.plugin_class) + if is_valid_plugin(plugin_instance): + plugins.append(plugin_instance) + else: + log.warning("Plugin %s is not a valid plugin", plugin.name) + except ImportError: + log.exception("Failed to load plugin %s from class name %s", plugin.name, plugin.plugin_class) + return plugins, import_errors + + +@cache +def _get_plugins() -> tuple[list[AirflowPlugin], dict[str, str]]: + """ + Load plugins from plugins directory and entrypoints. + + Plugins are only loaded if they have not been previously loaded. 
+    """
+    if not settings.PLUGINS_FOLDER:
+        raise ValueError("Plugins folder is not set")
+
+    log.debug("Loading plugins")
+
+    plugins: list[AirflowPlugin] = []
+    import_errors: dict[str, str] = {}
+    loaded_plugins: set[str | None] = set()
+
+    def __register_plugins(plugin_instances: list[AirflowPlugin], errors: dict[str, str]) -> None:
+        for plugin_instance in plugin_instances:
+            # Skip plugin names that were already registered by an earlier source
+            # (directory plugins take precedence over entrypoint plugins).
+            if plugin_instance.name in loaded_plugins:
+                continue
+
+            loaded_plugins.add(plugin_instance.name)
+            try:
+                plugin_instance.on_load()
+                plugins.append(plugin_instance)
+            except Exception as e:
+                log.exception("Failed to load plugin %s", plugin_instance.name)
+                name = str(plugin_instance.source) if plugin_instance.source else plugin_instance.name or ""
+                import_errors[name] = str(e)
+        import_errors.update(errors)
+
+    with Stats.timer() as timer:
+        load_examples = conf.getboolean("core", "LOAD_EXAMPLES")
+        __register_plugins(
+            *_load_plugins_from_plugin_directory(
+                plugins_folder=settings.PLUGINS_FOLDER,
+                load_examples=load_examples,
+                example_plugins_module="airflow.example_dags.plugins" if load_examples else None,
+            )
+        )
+        __register_plugins(*_load_entrypoint_plugins())
+
+    if not settings.LAZY_LOAD_PROVIDERS:
+        __register_plugins(*_load_providers_plugins())
+
+    log.debug("Loading %d plugin(s) took %.2f seconds", len(plugins), timer.duration)
+    return plugins, import_errors
+
+
+@cache
+def integrate_macros_plugins() -> None:
+    """Integrate macro plugins."""
+    from airflow.sdk.execution_time import macros
+
+    plugins, _ = _get_plugins()
+    _integrate_macros_plugins(
+        target_macros_module=macros,
+        macros_module_name_prefix="airflow.sdk.execution_time.macros",
+        plugins=plugins,
+    )
+
+
+def integrate_listener_plugins(listener_manager: ListenerManager) -> None:
+    """Add listeners from plugins."""
+    plugins, _ = _get_plugins()
+    _integrate_listener_plugins(listener_manager, plugins=plugins)
diff --git a/task-sdk/tests/task_sdk/docs/test_public_api.py b/task-sdk/tests/task_sdk/docs/test_public_api.py
index 5b29f12395e2d..f02b9ccebf91f 100644
--- a/task-sdk/tests/task_sdk/docs/test_public_api.py
+++ b/task-sdk/tests/task_sdk/docs/test_public_api.py
@@ -59,6 +59,7 @@ def test_airflow_sdk_no_unexpected_exports():
         "yaml",
         "serde",
         "observability",
+        "plugins_manager",
     }
     unexpected = actual - public - ignore
     assert not unexpected, f"Unexpected exports in airflow.sdk: {sorted(unexpected)}"
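
With these pieces in place, provider plugins import AirflowPlugin through the
common.compat shim instead of airflow.plugins_manager directly. A minimal sketch of
the resulting provider-side usage -- ExampleProviderPlugin is hypothetical, not part
of this change:

    # On Airflow 3.x with a new-enough Task SDK this resolves to
    # airflow.sdk.plugins_manager; on older versions the shim falls back to
    # airflow.plugins_manager (see the mapping added to compat/sdk.py above).
    from airflow.providers.common.compat.sdk import AirflowPlugin

    class ExampleProviderPlugin(AirflowPlugin):
        name = "example_provider_plugin"
        macros: list = []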