diff --git a/airflow-core/src/airflow/api_fastapi/app.py b/airflow-core/src/airflow/api_fastapi/app.py index 3261f073f194b..e2b4373302477 100644 --- a/airflow-core/src/airflow/api_fastapi/app.py +++ b/airflow-core/src/airflow/api_fastapi/app.py @@ -19,7 +19,7 @@ import logging from contextlib import AsyncExitStack, asynccontextmanager from functools import cache -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING from urllib.parse import urlsplit from fastapi import FastAPI @@ -31,7 +31,6 @@ init_error_handlers, init_flask_plugins, init_middlewares, - init_ui_plugins, init_views, ) from airflow.api_fastapi.execution_api.app import create_task_execution_api_app @@ -99,7 +98,6 @@ def create_app(apps: str = "all") -> FastAPI: init_plugins(app) init_auth_manager(app) init_flask_plugins(app) - init_ui_plugins(app) init_views(app) # Core views need to be the last routes added - it has a catch all route init_error_handlers(app) init_middlewares(app) @@ -171,10 +169,9 @@ def init_plugins(app: FastAPI) -> None: """Integrate FastAPI app, middlewares and UI plugins.""" from airflow import plugins_manager - plugins_manager.initialize_fastapi_plugins() + apps, root_middlewares = plugins_manager.get_fastapi_plugins() - # After calling initialize_fastapi_plugins, fastapi_apps cannot be None anymore. - for subapp_dict in cast("list", plugins_manager.fastapi_apps): + for subapp_dict in apps: name = subapp_dict.get("name") subapp = subapp_dict.get("app") if subapp is None: @@ -194,8 +191,7 @@ def init_plugins(app: FastAPI) -> None: log.debug("Adding subapplication %s under prefix %s", name, url_prefix) app.mount(url_prefix, subapp) - # After calling initialize_fastapi_plugins, fastapi_root_middlewares cannot be None anymore. - for middleware_dict in cast("list", plugins_manager.fastapi_root_middlewares): + for middleware_dict in root_middlewares: name = middleware_dict.get("name") middleware = middleware_dict.get("middleware") args = middleware_dict.get("args", []) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/app.py b/airflow-core/src/airflow/api_fastapi/core_api/app.py index 1f370f5844d08..8a8f525265a2f 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/app.py @@ -113,14 +113,10 @@ def init_flask_plugins(app: FastAPI) -> None: """Integrate Flask plugins (plugins from Airflow 2).""" from airflow import plugins_manager - plugins_manager.initialize_flask_plugins() + blueprints, appbuilder_views, appbuilder_menu_links = plugins_manager.get_flask_plugins() # If no Airflow 2.x plugin is in the environment, no need to go further - if ( - not plugins_manager.flask_blueprints - and not plugins_manager.flask_appbuilder_views - and not plugins_manager.flask_appbuilder_menu_links - ): + if not blueprints and not appbuilder_views and not appbuilder_menu_links: return from fastapi.middleware.wsgi import WSGIMiddleware @@ -190,10 +186,3 @@ def init_middlewares(app: FastAPI) -> None: from airflow.api_fastapi.auth.managers.simple.middleware import SimpleAllAdminMiddleware app.add_middleware(SimpleAllAdminMiddleware) - - -def init_ui_plugins(app: FastAPI) -> None: - """Initialize UI plugins.""" - from airflow import plugins_manager - - plugins_manager.initialize_ui_plugins() diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/plugins.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/plugins.py index 3b5368f50d232..57c27a99b6869 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/plugins.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/plugins.py @@ -75,13 +75,10 @@ def get_plugins( dependencies=[Depends(requires_access_view(AccessView.PLUGINS))], ) def import_errors() -> PluginImportErrorCollectionResponse: - plugins_manager.ensure_plugins_loaded() # make sure import_errors are loaded - + import_errors = plugins_manager.get_import_errors() return PluginImportErrorCollectionResponse.model_validate( { - "import_errors": [ - {"source": source, "error": error} for source, error in plugins_manager.import_errors.items() - ], - "total_entries": len(plugins_manager.import_errors), + "import_errors": [{"source": source, "error": error} for source, error in import_errors.items()], + "total_entries": len(import_errors), } ) diff --git a/airflow-core/src/airflow/cli/commands/plugins_command.py b/airflow-core/src/airflow/cli/commands/plugins_command.py index 29dd75674afe0..595ec7aba1056 100644 --- a/airflow-core/src/airflow/cli/commands/plugins_command.py +++ b/airflow-core/src/airflow/cli/commands/plugins_command.py @@ -16,35 +16,18 @@ # under the License. from __future__ import annotations -import inspect -from typing import Any - -from airflow import plugins_manager from airflow.cli.simple_table import AirflowConsole -from airflow.plugins_manager import PluginsDirectorySource, get_plugin_info +from airflow.plugins_manager import get_plugin_info from airflow.utils.cli import suppress_logs_and_warning from airflow.utils.providers_configuration_loader import providers_configuration_loaded -def _get_name(class_like_object) -> str: - if isinstance(class_like_object, (str, PluginsDirectorySource)): - return str(class_like_object) - if inspect.isclass(class_like_object): - return class_like_object.__name__ - return class_like_object.__class__.__name__ - - -def _join_plugins_names(value: list[Any] | Any) -> str: - value = value if isinstance(value, list) else [value] - return ",".join(_get_name(v) for v in value) - - @suppress_logs_and_warning @providers_configuration_loaded def dump_plugins(args): """Dump plugins information.""" plugins_info: list[dict[str, str]] = get_plugin_info() - if not plugins_manager.plugins: + if not plugins_info: print("No plugins loaded") return diff --git a/airflow-core/src/airflow/lineage/hook.py b/airflow-core/src/airflow/lineage/hook.py index 41c0e2bd4bc93..7c22a36700645 100644 --- a/airflow-core/src/airflow/lineage/hook.py +++ b/airflow-core/src/airflow/lineage/hook.py @@ -337,7 +337,6 @@ def get_hook_lineage_collector() -> HookLineageCollector: """Get singleton lineage collector.""" from airflow import plugins_manager - plugins_manager.initialize_hook_lineage_readers_plugins() - if plugins_manager.hook_lineage_reader_classes: + if plugins_manager.get_hook_lineage_readers_plugins(): return HookLineageCollector() return NoOpCollector() diff --git a/airflow-core/src/airflow/plugins_manager.py b/airflow-core/src/airflow/plugins_manager.py index c32bcae8ea5c2..8fdb3f0635638 100644 --- a/airflow-core/src/airflow/plugins_manager.py +++ b/airflow-core/src/airflow/plugins_manager.py @@ -27,6 +27,7 @@ import sys import types from collections.abc import Iterable +from functools import cache from pathlib import Path from typing import TYPE_CHECKING, Any @@ -43,10 +44,10 @@ if TYPE_CHECKING: from airflow.lineage.hook import HookLineageReader - try: + if sys.version_info >= (3, 12): + from importlib import metadata + else: import importlib_metadata as metadata - except ImportError: - from importlib import metadata # type: ignore[no-redef] from collections.abc import Generator from types import ModuleType @@ -55,55 +56,6 @@ log = logging.getLogger(__name__) -import_errors: dict[str, str] = {} - -plugins: list[AirflowPlugin] | None = None -loaded_plugins: set[str] = set() - -# Plugin components to integrate as modules -macros_modules: list[Any] | None = None - -# Plugin components to integrate directly -admin_views: list[Any] | None = None -flask_blueprints: list[Any] | None = None -fastapi_apps: list[Any] | None = None -fastapi_root_middlewares: list[Any] | None = None -external_views: list[Any] | None = None -react_apps: list[Any] | None = None -menu_links: list[Any] | None = None -flask_appbuilder_views: list[Any] | None = None -flask_appbuilder_menu_links: list[Any] | None = None -global_operator_extra_links: list[Any] | None = None -operator_extra_links: list[Any] | None = None -registered_operator_link_classes: dict[str, type] | None = None -timetable_classes: dict[str, type[Timetable]] | None = None -hook_lineage_reader_classes: list[type[HookLineageReader]] | None = None -priority_weight_strategy_classes: dict[str, type[PriorityWeightStrategy]] | None = None -""" -Mapping of class names to class of OperatorLinks registered by plugins. - -Used by the DAG serialization code to only allow specific classes to be created -during deserialization -""" -PLUGINS_ATTRIBUTES_TO_DUMP = { - "macros", - "admin_views", - "flask_blueprints", - "fastapi_apps", - "fastapi_root_middlewares", - "external_views", - "react_apps", - "menu_links", - "appbuilder_views", - "appbuilder_menu_items", - "global_operator_extra_links", - "operator_extra_links", - "source", - "timetables", - "listeners", - "priority_weight_strategies", -} - class AirflowPluginSource: """Class used to define an AirflowPluginSource.""" @@ -205,7 +157,7 @@ def on_load(cls, *args, **kwargs): """ -def is_valid_plugin(plugin_obj): +def is_valid_plugin(plugin_obj) -> bool: """ Check whether a potential object is a subclass of the AirflowPlugin class. @@ -219,27 +171,11 @@ def is_valid_plugin(plugin_obj): and (plugin_obj is not AirflowPlugin) ): plugin_obj.validate() - return plugin_obj not in plugins + return True return False -def register_plugin(plugin_instance): - """ - Start plugin load and register it after success initialization. - - If plugin is already registered, do nothing. - - :param plugin_instance: subclass of AirflowPlugin - """ - if plugin_instance.name in loaded_plugins: - return - - loaded_plugins.add(plugin_instance.name) - plugin_instance.on_load() - plugins.append(plugin_instance) - - -def load_entrypoint_plugins(): +def _load_entrypoint_plugins() -> tuple[list[AirflowPlugin], dict[str, str]]: """ Load and register plugins AirflowPlugin subclasses from the entrypoints. @@ -247,6 +183,8 @@ def load_entrypoint_plugins(): """ log.debug("Loading plugins from entrypoints") + plugins: list[AirflowPlugin] = [] + import_errors: dict[str, str] = {} for entry_point, dist in entry_points_with_dist("airflow.plugins"): log.debug("Importing entry_point plugin %s", entry_point.name) try: @@ -254,28 +192,33 @@ def load_entrypoint_plugins(): if not is_valid_plugin(plugin_class): continue - plugin_instance = plugin_class() + plugin_instance: AirflowPlugin = plugin_class() plugin_instance.source = EntryPointSource(entry_point, dist) - register_plugin(plugin_instance) + plugins.append(plugin_instance) except Exception as e: log.exception("Failed to import plugin %s", entry_point.name) import_errors[entry_point.module] = str(e) + return plugins, import_errors -def load_plugins_from_plugin_directory(): +def _load_plugins_from_plugin_directory() -> tuple[list[AirflowPlugin], dict[str, str]]: """Load and register Airflow Plugins from plugins directory.""" + if settings.PLUGINS_FOLDER is None: + raise ValueError("Plugins folder is not set") log.debug("Loading plugins from directory: %s", settings.PLUGINS_FOLDER) files = find_path_from_directory(settings.PLUGINS_FOLDER, ".airflowignore") plugin_search_locations: list[tuple[str, Generator[str, None, None]]] = [("", files)] if conf.getboolean("core", "LOAD_EXAMPLES"): log.debug("Note: Loading plugins from examples as well: %s", settings.PLUGINS_FOLDER) - from airflow.example_dags import plugins + from airflow.example_dags import plugins as example_plugins - example_plugins_folder = next(iter(plugins.__path__)) + example_plugins_folder = next(iter(example_plugins.__path__)) example_files = find_path_from_directory(example_plugins_folder, ".airflowignore") - plugin_search_locations.append((plugins.__name__, example_files)) + plugin_search_locations.append((example_plugins.__name__, example_files)) + plugins: list[AirflowPlugin] = [] + import_errors: dict[str, str] = {} for module_prefix, plugin_files in plugin_search_locations: for file_path in plugin_files: path = Path(file_path) @@ -286,39 +229,47 @@ def load_plugins_from_plugin_directory(): try: loader = importlib.machinery.SourceFileLoader(mod_name, file_path) spec = importlib.util.spec_from_loader(mod_name, loader) + if not spec: + log.error("Could not load spec for module %s at %s", mod_name, file_path) + continue mod = importlib.util.module_from_spec(spec) sys.modules[spec.name] = mod loader.exec_module(mod) for mod_attr_value in (m for m in mod.__dict__.values() if is_valid_plugin(m)): - plugin_instance = mod_attr_value() + plugin_instance: AirflowPlugin = mod_attr_value() plugin_instance.source = PluginsDirectorySource(file_path) - register_plugin(plugin_instance) + plugins.append(plugin_instance) except Exception as e: log.exception("Failed to import plugin %s", file_path) import_errors[file_path] = str(e) + return plugins, import_errors -def load_providers_plugins(): +def _load_providers_plugins() -> tuple[list[AirflowPlugin], dict[str, str]]: from airflow.providers_manager import ProvidersManager log.debug("Loading plugins from providers") providers_manager = ProvidersManager() providers_manager.initialize_providers_plugins() + + plugins: list[AirflowPlugin] = [] + import_errors: dict[str, str] = {} for plugin in providers_manager.plugins: log.debug("Importing plugin %s from class %s", plugin.name, plugin.plugin_class) try: plugin_instance = import_string(plugin.plugin_class) if is_valid_plugin(plugin_instance): - register_plugin(plugin_instance) + plugins.append(plugin_instance) else: log.warning("Plugin %s is not a valid plugin", plugin.name) except ImportError: log.exception("Failed to load plugin %s from class name %s", plugin.name, plugin.plugin_class) + return plugins, import_errors -def make_module(name: str, objects: list[Any]): +def make_module(name: str, objects: list[Any]) -> ModuleType | None: """Create new module.""" if not objects: return None @@ -331,64 +282,69 @@ def make_module(name: str, objects: list[Any]): return module -def ensure_plugins_loaded(): +def ensure_plugins_loaded() -> None: """ Load plugins from plugins directory and entrypoints. Plugins are only loaded if they have not been previously loaded. """ - from airflow.observability.stats import Stats + _get_plugins() - global plugins - if plugins is not None: - log.debug("Plugins are already loaded. Skipping.") - return +@cache +def _get_plugins() -> tuple[list[AirflowPlugin], dict[str, str]]: + """ + Load plugins from plugins directory and entrypoints. + + Plugins are only loaded if they have not been previously loaded. + """ + from airflow.observability.stats import Stats if not settings.PLUGINS_FOLDER: raise ValueError("Plugins folder is not set") log.debug("Loading plugins") - with Stats.timer() as timer: - plugins = [] + plugins: list[AirflowPlugin] = [] + import_errors: dict[str, str] = {} + loaded_plugins: set[str | None] = set() - load_plugins_from_plugin_directory() - load_entrypoint_plugins() - - if not settings.LAZY_LOAD_PROVIDERS: - load_providers_plugins() - - if plugins: - log.debug("Loading %d plugin(s) took %.2f seconds", len(plugins), timer.duration) + def __register_plugins(plugin_instances: list[AirflowPlugin], errors: dict[str, str]) -> None: + for plugin_instance in plugin_instances: + if plugin_instance.name in loaded_plugins: + return + loaded_plugins.add(plugin_instance.name) + try: + plugin_instance.on_load() + plugins.append(plugin_instance) + except Exception as e: + log.exception("Failed to load plugin %s", plugin_instance.name) + name = str(plugin_instance.source) if plugin_instance.source else plugin_instance.name or "" + import_errors[name] = str(e) + import_errors.update(errors) -def initialize_ui_plugins(): - """Collect extension points for the UI.""" - global external_views - global react_apps + with Stats.timer() as timer: + __register_plugins(*_load_plugins_from_plugin_directory()) + __register_plugins(*_load_entrypoint_plugins()) - if external_views is not None and react_apps is not None: - return + if not settings.LAZY_LOAD_PROVIDERS: + __register_plugins(*_load_providers_plugins()) - ensure_plugins_loaded() + log.debug("Loading %d plugin(s) took %.2f seconds", len(plugins), timer.duration) + return plugins, import_errors - if plugins is None: - raise AirflowPluginException("Can't load plugins.") +@cache +def _get_ui_plugins() -> tuple[list[Any], list[Any]]: + """Collect extension points for the UI.""" log.debug("Initialize UI plugin") - seen_url_route = {} - external_views = [] - react_apps = [] - - def _remove_list_item(lst, item): - # Mutate in place the plugin's external views and react apps list to remove the invalid items - # because some function still access these plugin's attribute and not the - # global variables `external_views` `react_apps`. (get_plugin_info, for example) - lst.remove(item) + seen_url_routes: dict[str, str | None] = {} - for plugin in plugins: + external_views: list[Any] = [] + react_apps: list[Any] = [] + for plugin in _get_plugins()[0]: external_views_to_remove = [] react_apps_to_remove = [] for external_view in plugin.external_views: @@ -402,18 +358,18 @@ def _remove_list_item(lst, item): url_route = external_view.get("url_route") if url_route is None: continue - if url_route in seen_url_route: + if url_route in seen_url_routes: log.warning( "Plugin '%s' has an external view with an URL route '%s' " "that conflicts with another plugin '%s'. The view will not be loaded.", plugin.name, url_route, - seen_url_route[url_route], + seen_url_routes[url_route], ) external_views_to_remove.append(external_view) continue external_views.append(external_view) - seen_url_route[url_route] = plugin.name + seen_url_routes[url_route] = plugin.name for react_app in plugin.react_apps: if not isinstance(react_app, dict): @@ -426,50 +382,35 @@ def _remove_list_item(lst, item): url_route = react_app.get("url_route") if url_route is None: continue - if url_route in seen_url_route: + if url_route in seen_url_routes: log.warning( "Plugin '%s' has a React App with an URL route '%s' " "that conflicts with another plugin '%s'. The React App will not be loaded.", plugin.name, url_route, - seen_url_route[url_route], + seen_url_routes[url_route], ) react_apps_to_remove.append(react_app) continue react_apps.append(react_app) - seen_url_route[url_route] = plugin.name + seen_url_routes[url_route] = plugin.name for item in external_views_to_remove: - _remove_list_item(plugin.external_views, item) + plugin.external_views.remove(item) for item in react_apps_to_remove: - _remove_list_item(plugin.react_apps, item) - + plugin.react_apps.remove(item) + return external_views, react_apps -def initialize_flask_plugins(): - """Collect flask extension points for WEB UI (legacy).""" - global flask_blueprints - global flask_appbuilder_views - global flask_appbuilder_menu_links - - if ( - flask_blueprints is not None - and flask_appbuilder_views is not None - and flask_appbuilder_menu_links is not None - ): - return - - ensure_plugins_loaded() - - if plugins is None: - raise AirflowPluginException("Can't load plugins.") +@cache +def get_flask_plugins() -> tuple[list[Any], list[Any], list[Any]]: + """Collect and get flask extension points for WEB UI (legacy).""" log.debug("Initialize legacy Web UI plugin") - flask_blueprints = [] - flask_appbuilder_views = [] - flask_appbuilder_menu_links = [] - - for plugin in plugins: + flask_appbuilder_views: list[Any] = [] + flask_appbuilder_menu_links: list[Any] = [] + flask_blueprints: list[Any] = [] + for plugin in _get_plugins()[0]: flask_appbuilder_views.extend(plugin.appbuilder_views) flask_appbuilder_menu_links.extend(plugin.appbuilder_menu_items) flask_blueprints.extend([{"name": plugin.name, "blueprint": bp} for bp in plugin.flask_blueprints]) @@ -482,130 +423,82 @@ def initialize_flask_plugins(): "Please contact the author of the plugin.", plugin.name, ) + return flask_blueprints, flask_appbuilder_views, flask_appbuilder_menu_links -def initialize_fastapi_plugins(): +@cache +def get_fastapi_plugins() -> tuple[list[Any], list[Any]]: """Collect extension points for the API.""" - global fastapi_apps - global fastapi_root_middlewares - - if fastapi_apps is not None and fastapi_root_middlewares is not None: - return - - ensure_plugins_loaded() - - if plugins is None: - raise AirflowPluginException("Can't load plugins.") - log.debug("Initialize FastAPI plugins") - fastapi_apps = [] - fastapi_root_middlewares = [] - - for plugin in plugins: + fastapi_apps: list[Any] = [] + fastapi_root_middlewares: list[Any] = [] + for plugin in _get_plugins()[0]: fastapi_apps.extend(plugin.fastapi_apps) fastapi_root_middlewares.extend(plugin.fastapi_root_middlewares) + return fastapi_apps, fastapi_root_middlewares -def initialize_extra_operators_links_plugins(): - """Create modules for loaded extension from extra operators links plugins.""" - global global_operator_extra_links - global operator_extra_links - global registered_operator_link_classes - - if ( - global_operator_extra_links is not None - and operator_extra_links is not None - and registered_operator_link_classes is not None - ): - return - - ensure_plugins_loaded() - - if plugins is None: - raise AirflowPluginException("Can't load plugins.") - +@cache +def _get_extra_operators_links_plugins() -> tuple[list[Any], list[Any]]: + """Create and get modules for loaded extension from extra operators links plugins.""" log.debug("Initialize extra operators links plugins") - global_operator_extra_links = [] - operator_extra_links = [] - registered_operator_link_classes = {} - - for plugin in plugins: + global_operator_extra_links: list[Any] = [] + operator_extra_links: list[Any] = [] + for plugin in _get_plugins()[0]: global_operator_extra_links.extend(plugin.global_operator_extra_links) operator_extra_links.extend(list(plugin.operator_extra_links)) + return global_operator_extra_links, operator_extra_links - registered_operator_link_classes.update( - {qualname(link.__class__): link.__class__ for link in plugin.operator_extra_links} - ) +def get_global_operator_extra_links() -> list[Any]: + """Get global operator extra links registered by plugins.""" + return _get_extra_operators_links_plugins()[0] -def initialize_timetables_plugins(): - """Collect timetable classes registered by plugins.""" - global timetable_classes - if timetable_classes is not None: - return +def get_operator_extra_links() -> list[Any]: + """Get operator extra links registered by plugins.""" + return _get_extra_operators_links_plugins()[1] - ensure_plugins_loaded() - - if plugins is None: - raise AirflowPluginException("Can't load plugins.") +@cache +def get_timetables_plugins() -> dict[str, type[Timetable]]: + """Collect and get timetable classes registered by plugins.""" log.debug("Initialize extra timetables plugins") - timetable_classes = { + return { qualname(timetable_class): timetable_class - for plugin in plugins + for plugin in _get_plugins()[0] for timetable_class in plugin.timetables } -def initialize_hook_lineage_readers_plugins(): - """Collect hook lineage reader classes registered by plugins.""" - global hook_lineage_reader_classes - - if hook_lineage_reader_classes is not None: - return - - ensure_plugins_loaded() - - if plugins is None: - raise AirflowPluginException("Can't load plugins.") - +@cache +def get_hook_lineage_readers_plugins() -> list[type[HookLineageReader]]: + """Collect and get hook lineage reader classes registered by plugins.""" log.debug("Initialize hook lineage readers plugins") + result: list[type[HookLineageReader]] = [] - hook_lineage_reader_classes = [] - for plugin in plugins: - hook_lineage_reader_classes.extend(plugin.hook_lineage_readers) + for plugin in _get_plugins()[0]: + result.extend(plugin.hook_lineage_readers) + return result +@cache def integrate_macros_plugins() -> None: """Integrates macro plugins.""" - global macros_modules - from airflow.sdk.execution_time import macros - if macros_modules is not None: - return - - ensure_plugins_loaded() - - if plugins is None: - raise AirflowPluginException("Can't load plugins.") - log.debug("Integrate Macros plugins") - macros_modules = [] - - for plugin in plugins: + for plugin in _get_plugins()[0]: if plugin.name is None: raise AirflowPluginException("Invalid plugin name") macros_module = make_module(f"airflow.sdk.execution_time.macros.{plugin.name}", plugin.macros) if macros_module: - macros_modules.append(macros_module) sys.modules[macros_module.__name__] = macros_module # Register the newly created module on airflow.macros such that it # can be accessed when rendering templates. @@ -614,15 +507,12 @@ def integrate_macros_plugins() -> None: def integrate_listener_plugins(listener_manager: ListenerManager) -> None: """Add listeners from plugins.""" - ensure_plugins_loaded() - - if plugins: - for plugin in plugins: - if plugin.name is None: - raise AirflowPluginException("Invalid plugin name") + for plugin in _get_plugins()[0]: + if plugin.name is None: + raise AirflowPluginException("Invalid plugin name") - for listener in plugin.listeners: - listener_manager.add_listener(listener) + for listener in plugin.listeners: + listener_manager.add_listener(listener) def get_plugin_info(attrs_to_dump: Iterable[str] | None = None) -> list[dict[str, Any]]: @@ -631,79 +521,88 @@ def get_plugin_info(attrs_to_dump: Iterable[str] | None = None) -> list[dict[str :param attrs_to_dump: A list of plugin attributes to dump """ - ensure_plugins_loaded() - integrate_macros_plugins() - initialize_flask_plugins() - initialize_fastapi_plugins() - initialize_ui_plugins() - initialize_extra_operators_links_plugins() + get_flask_plugins() + get_fastapi_plugins() + get_global_operator_extra_links() + get_operator_extra_links() + _get_ui_plugins() if not attrs_to_dump: - attrs_to_dump = PLUGINS_ATTRIBUTES_TO_DUMP + attrs_to_dump = { + "macros", + "admin_views", + "flask_blueprints", + "fastapi_apps", + "fastapi_root_middlewares", + "external_views", + "react_apps", + "menu_links", + "appbuilder_views", + "appbuilder_menu_items", + "global_operator_extra_links", + "operator_extra_links", + "source", + "timetables", + "listeners", + "priority_weight_strategies", + } plugins_info = [] - if plugins: - for plugin in plugins: - info: dict[str, Any] = {"name": plugin.name} - for attr in attrs_to_dump: - if attr in ("global_operator_extra_links", "operator_extra_links"): - info[attr] = [f"<{qualname(d.__class__)} object>" for d in getattr(plugin, attr)] - elif attr in ("macros", "timetables", "priority_weight_strategies"): - info[attr] = [qualname(d) for d in getattr(plugin, attr)] - elif attr == "listeners": - # listeners may be modules or class instances - info[attr] = [ - d.__name__ if inspect.ismodule(d) else qualname(d) for d in getattr(plugin, attr) - ] - elif attr == "appbuilder_views": - info[attr] = [ - {**d, "view": qualname(d["view"].__class__) if "view" in d else None} - for d in getattr(plugin, attr) - ] - elif attr == "flask_blueprints": - info[attr] = [ - f"<{qualname(d.__class__)}: name={d.name!r} import_name={d.import_name!r}>" - for d in getattr(plugin, attr) - ] - elif attr == "fastapi_apps": - info[attr] = [ - {**d, "app": qualname(d["app"].__class__) if "app" in d else None} - for d in getattr(plugin, attr) - ] - elif attr == "fastapi_root_middlewares": - # remove args and kwargs from plugin info to hide potentially sensitive info. - info[attr] = [ - { - k: (v if k != "middleware" else qualname(middleware_dict["middleware"])) - for k, v in middleware_dict.items() - if k not in ("args", "kwargs") - } - for middleware_dict in getattr(plugin, attr) - ] - else: - info[attr] = getattr(plugin, attr) - plugins_info.append(info) + for plugin in _get_plugins()[0]: + info: dict[str, Any] = {"name": plugin.name} + for attr in attrs_to_dump: + if attr in ("global_operator_extra_links", "operator_extra_links"): + info[attr] = [f"<{qualname(d.__class__)} object>" for d in getattr(plugin, attr)] + elif attr in ("macros", "timetables", "priority_weight_strategies"): + info[attr] = [qualname(d) for d in getattr(plugin, attr)] + elif attr == "listeners": + # listeners may be modules or class instances + info[attr] = [d.__name__ if inspect.ismodule(d) else qualname(d) for d in plugin.listeners] + elif attr == "appbuilder_views": + info[attr] = [ + {**d, "view": qualname(d["view"].__class__) if "view" in d else None} + for d in plugin.appbuilder_views + ] + elif attr == "flask_blueprints": + info[attr] = [ + f"<{qualname(d.__class__)}: name={d.name!r} import_name={d.import_name!r}>" + for d in plugin.flask_blueprints + ] + elif attr == "fastapi_apps": + info[attr] = [ + {**d, "app": qualname(d["app"].__class__) if "app" in d else None} + for d in plugin.fastapi_apps + ] + elif attr == "fastapi_root_middlewares": + # remove args and kwargs from plugin info to hide potentially sensitive info. + info[attr] = [ + { + k: (v if k != "middleware" else qualname(middleware_dict["middleware"])) + for k, v in middleware_dict.items() + if k not in ("args", "kwargs") + } + for middleware_dict in plugin.fastapi_root_middlewares + ] + else: + info[attr] = getattr(plugin, attr) + plugins_info.append(info) return plugins_info -def initialize_priority_weight_strategy_plugins(): - """Collect priority weight strategy classes registered by plugins.""" - global priority_weight_strategy_classes - - if priority_weight_strategy_classes is not None: - return - - ensure_plugins_loaded() - - if plugins is None: - raise AirflowPluginException("Can't load plugins.") - +@cache +def get_priority_weight_strategy_plugins() -> dict[str, type[PriorityWeightStrategy]]: + """Collect and get priority weight strategy classes registered by plugins.""" log.debug("Initialize extra priority weight strategy plugins") plugins_priority_weight_strategy_classes = { qualname(priority_weight_strategy_class): priority_weight_strategy_class - for plugin in plugins + for plugin in _get_plugins()[0] for priority_weight_strategy_class in plugin.priority_weight_strategies } - priority_weight_strategy_classes = { + return { **airflow_priority_weight_strategies, **plugins_priority_weight_strategy_classes, } + + +def get_import_errors() -> dict[str, str]: + """Get import errors encountered during plugin loading.""" + return _get_plugins()[1] diff --git a/airflow-core/src/airflow/serialization/definitions/baseoperator.py b/airflow-core/src/airflow/serialization/definitions/baseoperator.py index 13ea87ae47cc0..20cac69435fac 100644 --- a/airflow-core/src/airflow/serialization/definitions/baseoperator.py +++ b/airflow-core/src/airflow/serialization/definitions/baseoperator.py @@ -24,7 +24,6 @@ import methodtools -from airflow.exceptions import AirflowException from airflow.serialization.definitions.node import DAGNode from airflow.serialization.definitions.param import SerializedParamsDict from airflow.serialization.enums import DagAttributeTypes @@ -262,10 +261,7 @@ def global_operator_extra_link_dict(self) -> dict[str, Any]: """All global extra links.""" from airflow import plugins_manager - plugins_manager.initialize_extra_operators_links_plugins() - if plugins_manager.global_operator_extra_links is None: - raise AirflowException("Can't load operators") - return {link.name: link for link in plugins_manager.global_operator_extra_links} + return {link.name: link for link in plugins_manager.get_global_operator_extra_links()} @functools.cached_property def extra_links(self) -> list[str]: diff --git a/airflow-core/src/airflow/serialization/definitions/mappedoperator.py b/airflow-core/src/airflow/serialization/definitions/mappedoperator.py index 561a23f0b9ad1..1cf6d357e651a 100644 --- a/airflow-core/src/airflow/serialization/definitions/mappedoperator.py +++ b/airflow-core/src/airflow/serialization/definitions/mappedoperator.py @@ -28,7 +28,7 @@ import structlog from sqlalchemy.orm import Session -from airflow.exceptions import AirflowException, NotMapped +from airflow.exceptions import NotMapped from airflow.sdk import BaseOperator as TaskSDKBaseOperator from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator from airflow.serialization.definitions.baseoperator import DEFAULT_OPERATOR_DEPS, SerializedBaseOperator @@ -372,11 +372,9 @@ def operator_extra_link_dict(self) -> dict[str, XComOperatorLink]: op_extra_links_from_plugin: dict[str, Any] = {} from airflow import plugins_manager - plugins_manager.initialize_extra_operators_links_plugins() - if plugins_manager.operator_extra_links is None: - raise AirflowException("Can't load operators") + operator_extra_links = plugins_manager.get_operator_extra_links() operator_class_type = self.operator_class["task_type"] # type: ignore - for ope in plugins_manager.operator_extra_links: + for ope in operator_extra_links: if ope.operators and any(operator_class_type in cls.__name__ for cls in ope.operators): op_extra_links_from_plugin.update({ope.name: ope}) @@ -391,10 +389,8 @@ def global_operator_extra_link_dict(self) -> dict[str, Any]: """Returns dictionary of all global extra links.""" from airflow import plugins_manager - plugins_manager.initialize_extra_operators_links_plugins() - if plugins_manager.global_operator_extra_links is None: - raise AirflowException("Can't load operators") - return {link.name: link for link in plugins_manager.global_operator_extra_links} + global_operator_extra_links = plugins_manager.get_global_operator_extra_links() + return {link.name: link for link in global_operator_extra_links} @functools.cached_property def extra_links(self) -> list[str]: diff --git a/airflow-core/src/airflow/serialization/helpers.py b/airflow-core/src/airflow/serialization/helpers.py index 3af7f3c07e7f1..723c113709a87 100644 --- a/airflow-core/src/airflow/serialization/helpers.py +++ b/airflow-core/src/airflow/serialization/helpers.py @@ -116,10 +116,9 @@ def find_registered_custom_timetable(importable_string: str) -> type[CoreTimetab """Find a user-defined custom timetable class registered via a plugin.""" from airflow import plugins_manager - plugins_manager.initialize_timetables_plugins() - if plugins_manager.timetable_classes is not None: - with contextlib.suppress(KeyError): - return plugins_manager.timetable_classes[importable_string] + timetable_classes = plugins_manager.get_timetables_plugins() + with contextlib.suppress(KeyError): + return timetable_classes[importable_string] raise TimetableNotRegistered(importable_string) diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index 170879c5339ed..53b6377b96e4e 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -131,11 +131,7 @@ def _get_registered_priority_weight_strategy( if importable_string in airflow_priority_weight_strategies: return airflow_priority_weight_strategies[importable_string] - plugins_manager.initialize_priority_weight_strategy_plugins() - if plugins_manager.priority_weight_strategy_classes: - return plugins_manager.priority_weight_strategy_classes.get(importable_string) - else: - return None + return plugins_manager.get_priority_weight_strategy_plugins().get(importable_string) class _PartitionMapperNotFound(ValueError): @@ -1159,12 +1155,7 @@ def populate_operator( if cls._load_operator_extra_links: from airflow import plugins_manager - plugins_manager.initialize_extra_operators_links_plugins() - - if plugins_manager.operator_extra_links is None: - raise AirflowException("Can not load plugins") - - for ope in plugins_manager.operator_extra_links: + for ope in plugins_manager.get_operator_extra_links(): for operator in ope.operators: if ( operator.__name__ == encoded_op["task_type"] @@ -1545,10 +1536,7 @@ def _deserialize_operator_extra_links( """ from airflow import plugins_manager - plugins_manager.initialize_extra_operators_links_plugins() - - if plugins_manager.registered_operator_link_classes is None: - raise AirflowException("Can't load plugins") + plugins_manager.get_operator_extra_links() op_predefined_extra_links = {} for name, xcom_key in encoded_op_links.items(): diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py index 597c1f819874a..3720994e04df7 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py @@ -82,11 +82,10 @@ def custom_timetable_plugin(monkeypatch): timetable_class_name = qualname(CustomTimetable) existing_timetables = getattr(plugins_manager, "timetable_classes", None) or {} - monkeypatch.setattr(plugins_manager, "initialize_timetables_plugins", lambda: None) monkeypatch.setattr( plugins_manager, - "timetable_classes", - {**existing_timetables, timetable_class_name: CustomTimetable}, + "get_timetables_plugins", + lambda: {**existing_timetables, timetable_class_name: CustomTimetable}, ) diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_plugins.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_plugins.py index 348b8d9526ede..4220837ebb1ba 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_plugins.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_plugins.py @@ -162,8 +162,8 @@ def test_invalid_external_view_destination_should_log_warning_and_continue(self, @skip_if_force_lowest_dependencies_marker class TestGetPluginImportErrors: @patch( - "airflow.plugins_manager.import_errors", - new={"plugins/test_plugin.py": "something went wrong"}, + "airflow.plugins_manager.get_import_errors", + new=lambda: {"plugins/test_plugin.py": "something went wrong"}, ) def test_should_respond_200(self, test_client, session): with assert_queries_count(2): diff --git a/airflow-core/tests/unit/api_fastapi/test_app.py b/airflow-core/tests/unit/api_fastapi/test_app.py index 1e0925f8e7349..1eb692e186467 100644 --- a/airflow-core/tests/unit/api_fastapi/test_app.py +++ b/airflow-core/tests/unit/api_fastapi/test_app.py @@ -113,7 +113,7 @@ def test_catch_all_route_last(client): ) def test_plugin_with_invalid_url_prefix(caplog, fastapi_apps, expected_message, invalid_path): app = FastAPI() - with mock.patch.object(plugins_manager, "fastapi_apps", fastapi_apps): + with mock.patch.object(plugins_manager, "get_fastapi_plugins", return_value=(fastapi_apps, [])): app_module.init_plugins(app) assert any(expected_message in rec.message for rec in caplog.records) diff --git a/airflow-core/tests/unit/plugins/test_plugins_manager.py b/airflow-core/tests/unit/plugins/test_plugins_manager.py index c463f6a62739e..2b872f23c0bd3 100644 --- a/airflow-core/tests/unit/plugins/test_plugins_manager.py +++ b/airflow-core/tests/unit/plugins/test_plugins_manager.py @@ -75,8 +75,7 @@ class TestPluginsManager: def clean_plugins(self): from airflow import plugins_manager - plugins_manager.loaded_plugins = set() - plugins_manager.plugins = [] + plugins_manager._get_plugins.cache_clear() def test_no_log_when_no_plugins(self, caplog): with mock_plugin_manager(plugins=[]): @@ -89,33 +88,37 @@ def test_no_log_when_no_plugins(self, caplog): def test_loads_filesystem_plugins(self, caplog): from airflow import plugins_manager - with mock.patch("airflow.plugins_manager.plugins", []): - plugins_manager.load_plugins_from_plugin_directory() + plugins, import_errors = plugins_manager._load_plugins_from_plugin_directory() - assert len(plugins_manager.plugins) == 10 - for plugin in plugins_manager.plugins: - if "AirflowTestOnLoadPlugin" in str(plugin): - assert plugin.name == "postload" - break - else: - pytest.fail("Wasn't able to find a registered `AirflowTestOnLoadPlugin`") + assert len(plugins) == 10 + assert not import_errors + for plugin in plugins: + if "AirflowTestOnLoadPlugin" in str(plugin): + assert plugin.name == "preload" # on_init() is not called here + break + else: + pytest.fail("Wasn't able to find a registered `AirflowTestOnLoadPlugin`") - assert caplog.record_tuples == [] + assert caplog.record_tuples == [] def test_loads_filesystem_plugins_exception(self, caplog, tmp_path): from airflow import plugins_manager - with mock.patch("airflow.plugins_manager.plugins", []): - (tmp_path / "testplugin.py").write_text(ON_LOAD_EXCEPTION_PLUGIN) + (tmp_path / "testplugin.py").write_text(ON_LOAD_EXCEPTION_PLUGIN) - with conf_vars({("core", "plugins_folder"): os.fspath(tmp_path)}): - plugins_manager.load_plugins_from_plugin_directory() + with ( + conf_vars({("core", "plugins_folder"): os.fspath(tmp_path)}), + mock.patch("airflow.plugins_manager._load_entrypoint_plugins", return_value=([], [])), + mock.patch("airflow.plugins_manager._load_providers_plugins", return_value=([], [])), + ): + plugins, import_errors = plugins_manager._get_plugins() - assert len(plugins_manager.plugins) == 3 # three are loaded from examples + assert len(plugins) == 3 # three are loaded from examples + assert len(import_errors) == 1 - received_logs = caplog.text - assert "Failed to import plugin" in received_logs - assert "testplugin.py" in received_logs + received_logs = caplog.text + assert "Failed to load plugin" in received_logs + assert "testplugin.py" in received_logs def test_should_warning_about_incompatible_plugins(self, caplog): class AirflowAdminViewsPlugin(AirflowPlugin): @@ -134,7 +137,7 @@ class AirflowAdminMenuLinksPlugin(AirflowPlugin): ): from airflow import plugins_manager - plugins_manager.initialize_flask_plugins() + plugins_manager.get_flask_plugins() assert caplog.record_tuples == [ ( @@ -169,14 +172,16 @@ class TestPluginB(AirflowPlugin): ): from airflow import plugins_manager - plugins_manager.initialize_ui_plugins() + external_views, react_apps = plugins_manager._get_ui_plugins() # Verify that the conflicting external view and react app are not loaded - plugin_b = next(plugin for plugin in plugins_manager.plugins if plugin.name == "test_plugin_b") + plugin_b = next( + plugin for plugin in plugins_manager._get_plugins()[0] if plugin.name == "test_plugin_b" + ) assert plugin_b.external_views == [] assert plugin_b.react_apps == [] - assert len(plugins_manager.external_views) == 1 - assert len(plugins_manager.react_apps) == 0 + assert len(external_views) == 1 + assert len(react_apps) == 0 def test_should_warning_about_external_views_or_react_app_wrong_object(self, caplog): class TestPluginA(AirflowPlugin): @@ -191,14 +196,16 @@ class TestPluginA(AirflowPlugin): ): from airflow import plugins_manager - plugins_manager.initialize_ui_plugins() + external_views, react_apps = plugins_manager._get_ui_plugins() # Verify that the conflicting external view and react app are not loaded - plugin_a = next(plugin for plugin in plugins_manager.plugins if plugin.name == "test_plugin_a") + plugin_a = next( + plugin for plugin in plugins_manager._get_plugins()[0] if plugin.name == "test_plugin_a" + ) assert plugin_a.external_views == [{"url_route": "/test_route"}] assert plugin_a.react_apps == [{"url_route": "/test_route_react_app"}] - assert len(plugins_manager.external_views) == 1 - assert len(plugins_manager.react_apps) == 1 + assert len(external_views) == 1 + assert len(react_apps) == 1 assert caplog.record_tuples == [ ( @@ -232,7 +239,7 @@ class AirflowAdminMenuLinksPlugin(AirflowPlugin): ): from airflow import plugins_manager - plugins_manager.initialize_flask_plugins() + plugins_manager.get_flask_plugins() assert caplog.record_tuples == [] @@ -255,7 +262,7 @@ class AirflowAdminMenuLinksPlugin(AirflowPlugin): ): from airflow import plugins_manager - plugins_manager.initialize_flask_plugins() + plugins_manager.get_flask_plugins() assert caplog.record_tuples == [] @@ -263,7 +270,7 @@ def test_entrypoint_plugin_errors_dont_raise_exceptions(self, mock_metadata_dist """ Test that Airflow does not raise an error if there is any Exception because of a plugin. """ - from airflow.plugins_manager import import_errors, load_entrypoint_plugins + from airflow.plugins_manager import _load_entrypoint_plugins mock_dist = mock.Mock() mock_dist.metadata = {"Name": "test-dist"} @@ -279,14 +286,17 @@ def test_entrypoint_plugin_errors_dont_raise_exceptions(self, mock_metadata_dist mock_metadata_distribution(return_value=[mock_dist]), caplog.at_level(logging.ERROR, logger="airflow.plugins_manager"), ): - load_entrypoint_plugins() + _, import_errors = _load_entrypoint_plugins() received_logs = caplog.text # Assert Traceback is shown too assert "Traceback (most recent call last):" in received_logs assert "my_fake_module not found" in received_logs assert "Failed to import plugin test-entrypoint" in received_logs - assert ("test.plugins.test_plugins_manager", "my_fake_module not found") in import_errors.items() + assert ( + "test.plugins.test_plugins_manager", + "my_fake_module not found", + ) in import_errors.items() def test_registering_plugin_macros(self, request): """ @@ -326,15 +336,14 @@ class MacroPlugin(AirflowPlugin): # Verify that the symbol table in airflow.sdk.execution_time.macros has been updated with an entry for # this plugin, this is necessary in order to allow the plugin's macros to be used when # rendering templates. - assert hasattr(macros, MacroPlugin.name) + assert hasattr(macros, MacroPlugin.name or "") @skip_if_force_lowest_dependencies_marker def test_registering_plugin_listeners(self): from airflow import plugins_manager assert not get_listener_manager().has_listeners - with mock.patch("airflow.plugins_manager.plugins", []): - plugins_manager.load_plugins_from_plugin_directory() + with mock_plugin_manager(plugins=plugins_manager._load_plugins_from_plugin_directory()[0]): plugins_manager.integrate_listener_plugins(get_listener_manager()) assert get_listener_manager().has_listeners @@ -351,10 +360,9 @@ def test_registering_plugin_listeners(self): def test_should_import_plugin_from_providers(self): from airflow import plugins_manager - with mock.patch("airflow.plugins_manager.plugins", []): - assert len(plugins_manager.plugins) == 0 - plugins_manager.load_providers_plugins() - assert len(plugins_manager.plugins) >= 2 + plugins, import_errors = plugins_manager._load_providers_plugins() + assert len(plugins) >= 2 + assert not import_errors @skip_if_force_lowest_dependencies_marker def test_does_not_double_import_entrypoint_provider_plugins(self): @@ -369,11 +377,10 @@ def test_does_not_double_import_entrypoint_provider_plugins(self): mock_dist.version = "1.0.0" mock_dist.entry_points = [mock_entrypoint] - with mock.patch("airflow.plugins_manager.plugins", []): - assert len(plugins_manager.plugins) == 0 - plugins_manager.load_entrypoint_plugins() - plugins_manager.load_providers_plugins() - assert len(plugins_manager.plugins) == 4 + # Mock/skip loading from plugin dir + with mock.patch("airflow.plugins_manager._load_plugins_from_plugin_directory", return_value=([], [])): + plugins = plugins_manager._get_plugins()[0] + assert len(plugins) == 4 class TestPluginsDirectorySource: @@ -400,7 +407,7 @@ def test_should_return_correct_source_details(self, mock_metadata_distribution): mock_dist.entry_points = [mock_entrypoint] with mock_metadata_distribution(return_value=[mock_dist]): - plugins_manager.load_entrypoint_plugins() + plugins_manager._load_entrypoint_plugins() source = plugins_manager.EntryPointSource(mock_entrypoint, mock_dist) assert str(mock_entrypoint) == source.entrypoint diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index 7b87c6e00a89a..8e25e7663d12b 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -479,15 +479,16 @@ def serialize_subprocess(queue, dag_folder): @pytest.fixture -def timetable_plugin(monkeypatch): +def timetable_plugin(monkeypatch: pytest.MonkeyPatch): """Patch plugins manager to always and only return our custom timetable.""" from airflow import plugins_manager - monkeypatch.setattr(plugins_manager, "initialize_timetables_plugins", lambda: None) monkeypatch.setattr( plugins_manager, - "timetable_classes", - {"tests_common.test_utils.timetables.CustomSerializationTimetable": CustomSerializationTimetable}, + "get_timetables_plugins", + lambda: { + "tests_common.test_utils.timetables.CustomSerializationTimetable": CustomSerializationTimetable + }, ) diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index e622fd6a1288e..29ef46a1fec79 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -217,7 +217,7 @@ def mock_plugins_manager_for_all_non_db_tests(): return from tests_common.test_utils.mock_plugins import mock_plugin_manager - with mock_plugin_manager() as _fixture: + with mock_plugin_manager(plugins=[]) as _fixture: yield _fixture diff --git a/devel-common/src/tests_common/test_utils/mock_plugins.py b/devel-common/src/tests_common/test_utils/mock_plugins.py index 36d65a8235854..dfc654d15ee94 100644 --- a/devel-common/src/tests_common/test_utils/mock_plugins.py +++ b/devel-common/src/tests_common/test_utils/mock_plugins.py @@ -19,9 +19,9 @@ from contextlib import ExitStack, contextmanager from unittest import mock -from airflow import __version__ as airflow_version +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS -PLUGINS_MANAGER_NULLABLE_ATTRIBUTES = [ +PLUGINS_MANAGER_NULLABLE_ATTRIBUTES_V3_0 = [ "plugins", "macros_modules", "admin_views", @@ -75,41 +75,71 @@ def mock_plugin_manager(plugins=None, **kwargs): Use this context if you want your test to not have side effects in airflow.plugins_manager, and other tests do not affect the results of this test. """ - illegal_arguments = set(kwargs.keys()) - set(PLUGINS_MANAGER_NULLABLE_ATTRIBUTES) - {"import_errors"} + illegal_arguments = set(kwargs.keys()) - set(PLUGINS_MANAGER_NULLABLE_ATTRIBUTES_V3_0) - {"import_errors"} if illegal_arguments: raise TypeError( f"TypeError: mock_plugin_manager got an unexpected keyword arguments: {illegal_arguments}" ) # Handle plugins specially with ExitStack() as exit_stack: + if AIRFLOW_V_3_2_PLUS: + # Always start the block with an non-initialized plugins, so ensure_plugins_loaded runs. + from airflow import plugins_manager + + plugins_manager._get_plugins.cache_clear() + plugins_manager._get_ui_plugins.cache_clear() + plugins_manager.get_flask_plugins.cache_clear() + plugins_manager.get_fastapi_plugins.cache_clear() + plugins_manager._get_extra_operators_links_plugins.cache_clear() + plugins_manager.get_timetables_plugins.cache_clear() + plugins_manager.get_hook_lineage_readers_plugins.cache_clear() + plugins_manager.integrate_macros_plugins.cache_clear() + plugins_manager.get_priority_weight_strategy_plugins.cache_clear() + + if plugins is not None or "import_errors" in kwargs: + exit_stack.enter_context( + mock.patch( + "airflow.plugins_manager._get_plugins", + return_value=( + plugins or [], + kwargs.get("import_errors", {}), + ), + ) + ) + elif kwargs: + raise NotImplementedError( + "mock_plugin_manager does not support patching other attributes in Airflow 3.2+" + ) + else: - def mock_loaded_plugins(): - exit_stack.enter_context(mock.patch("airflow.plugins_manager.plugins", plugins or [])) + def mock_loaded_plugins(): + exit_stack.enter_context(mock.patch("airflow.plugins_manager.plugins", plugins or [])) - exit_stack.enter_context( - mock.patch( - "airflow.plugins_manager.load_plugins_from_plugin_directory", side_effect=mock_loaded_plugins + exit_stack.enter_context( + mock.patch( + "airflow.plugins_manager.load_plugins_from_plugin_directory", + side_effect=mock_loaded_plugins, + ) + ) + exit_stack.enter_context( + mock.patch("airflow.plugins_manager.load_providers_plugins", side_effect=mock_loaded_plugins) + ) + exit_stack.enter_context( + mock.patch("airflow.plugins_manager.load_entrypoint_plugins", side_effect=mock_loaded_plugins) ) - ) - exit_stack.enter_context( - mock.patch("airflow.plugins_manager.load_providers_plugins", side_effect=mock_loaded_plugins) - ) - exit_stack.enter_context( - mock.patch("airflow.plugins_manager.load_entrypoint_plugins", side_effect=mock_loaded_plugins) - ) - if airflow_version <= "3": - ATTR_TO_PATCH = PLUGINS_MANAGER_NULLABLE_ATTRIBUTES_V2_10 - else: - ATTR_TO_PATCH = PLUGINS_MANAGER_NULLABLE_ATTRIBUTES + if AIRFLOW_V_3_0_PLUS: + ATTR_TO_PATCH = PLUGINS_MANAGER_NULLABLE_ATTRIBUTES_V3_0 + else: + ATTR_TO_PATCH = PLUGINS_MANAGER_NULLABLE_ATTRIBUTES_V2_10 - for attr in ATTR_TO_PATCH: - exit_stack.enter_context(mock.patch(f"airflow.plugins_manager.{attr}", kwargs.get(attr))) + for attr in ATTR_TO_PATCH: + exit_stack.enter_context(mock.patch(f"airflow.plugins_manager.{attr}", kwargs.get(attr))) - # Always start the block with an empty plugins, so ensure_plugins_loaded runs. - exit_stack.enter_context(mock.patch("airflow.plugins_manager.plugins", None)) - exit_stack.enter_context( - mock.patch("airflow.plugins_manager.import_errors", kwargs.get("import_errors", {})) - ) + # Always start the block with an non-initialized plugins, so ensure_plugins_loaded runs. + exit_stack.enter_context(mock.patch("airflow.plugins_manager.plugins", None)) + exit_stack.enter_context( + mock.patch("airflow.plugins_manager.import_errors", kwargs.get("import_errors", {})) + ) yield diff --git a/providers/fab/.pre-commit-config.yaml b/providers/fab/.pre-commit-config.yaml index 5dcd81e510721..7f05001e80ef4 100644 --- a/providers/fab/.pre-commit-config.yaml +++ b/providers/fab/.pre-commit-config.yaml @@ -27,7 +27,12 @@ repos: - id: compile-fab-assets name: Compile FAB provider assets language: node - files: ^.*/www/ + files: ^src/airflow/providers/fab/www/ + exclude: | + (?x) + ^src/airflow/providers/fab/www/api_connexion/.*| + ^src/airflow/providers/fab/www/extensions/.*| + ^src/airflow/providers/fab/www/security/.* entry: ../../scripts/ci/prek/compile_provider_assets.py fab pass_filenames: false additional_dependencies: ['yarn@1.22.21'] diff --git a/providers/fab/src/airflow/providers/fab/version_compat.py b/providers/fab/src/airflow/providers/fab/version_compat.py index e1d9559cc311b..350e25ce81b0c 100644 --- a/providers/fab/src/airflow/providers/fab/version_compat.py +++ b/providers/fab/src/airflow/providers/fab/version_compat.py @@ -34,3 +34,4 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0) AIRFLOW_V_3_1_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 1) +AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0) diff --git a/providers/fab/src/airflow/providers/fab/www/extensions/init_views.py b/providers/fab/src/airflow/providers/fab/www/extensions/init_views.py index e7623e8217ed3..f26b698161513 100644 --- a/providers/fab/src/airflow/providers/fab/www/extensions/init_views.py +++ b/providers/fab/src/airflow/providers/fab/www/extensions/init_views.py @@ -26,7 +26,7 @@ from flask import request from airflow.api_fastapi.app import get_auth_manager -from airflow.providers.fab.version_compat import AIRFLOW_V_3_1_PLUS +from airflow.providers.fab.version_compat import AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_2_PLUS from airflow.providers.fab.www.api_connexion.exceptions import common_error_handler if TYPE_CHECKING: @@ -102,11 +102,17 @@ def init_plugins(app): """Integrate Flask and FAB with plugins.""" from airflow import plugins_manager - plugins_manager.initialize_flask_plugins() + if AIRFLOW_V_3_2_PLUS: + blueprints, appbuilder_views, appbuilder_menu_links = plugins_manager.get_flask_plugins() + else: + plugins_manager.initialize_flask_plugins() # type: ignore + blueprints = plugins_manager.flask_blueprints # type: ignore + appbuilder_views = plugins_manager.flask_appbuilder_views # type: ignore + appbuilder_menu_links = plugins_manager.flask_appbuilder_menu_links # type: ignore appbuilder = app.appbuilder - for view in plugins_manager.flask_appbuilder_views: + for view in appbuilder_views: name = view.get("name") if name: filtered_view_kwargs = {k: v for k, v in view.items() if k not in ["view"]} @@ -124,13 +130,11 @@ def init_plugins(app): # Since Airflow 3.1 flask_appbuilder_menu_links are added to the Airflow 3 UI # navbar.. if not AIRFLOW_V_3_1_PLUS: - for menu_link in sorted( - plugins_manager.flask_appbuilder_menu_links, key=lambda x: (x.get("category", ""), x["name"]) - ): + for menu_link in sorted(appbuilder_menu_links, key=lambda x: (x.get("category", ""), x["name"])): log.debug("Adding menu link %s to %s", menu_link["name"], menu_link["href"]) appbuilder.add_link(**menu_link) - for blue_print in plugins_manager.flask_blueprints: + for blue_print in blueprints: log.debug("Adding blueprint %s:%s", blue_print["name"], blue_print["blueprint"].import_name) app.register_blueprint(blue_print["blueprint"]) diff --git a/providers/fab/www-hash.txt b/providers/fab/www-hash.txt index 144721dd69c03..965b4823c9628 100644 --- a/providers/fab/www-hash.txt +++ b/providers/fab/www-hash.txt @@ -1 +1 @@ -b1cb162c904247bab244b9fb163a340665e64301d6c62926ee583e82edf3023d +b8ebef1806aed0a26ed8fd468040a51eaec700520bd41dbbdbf8136a4663eb6e diff --git a/task-sdk/tests/conftest.py b/task-sdk/tests/conftest.py index 6ae329e5345db..eadb2e7d59411 100644 --- a/task-sdk/tests/conftest.py +++ b/task-sdk/tests/conftest.py @@ -166,15 +166,12 @@ def _disable_ol_plugin(): # And we load plugins when setting the priority_weight field import airflow.plugins_manager - old = airflow.plugins_manager.plugins - - assert old is None, "Plugins already loaded, too late to stop them being loaded!" - - airflow.plugins_manager.plugins = [] + old = airflow.plugins_manager._get_plugins + airflow.plugins_manager._get_plugins = lambda: ([], {}) yield - airflow.plugins_manager.plugins = None + airflow.plugins_manager._get_plugins = old @pytest.fixture(autouse=True)