diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b81f6f9a37e3e..33ae4e64b3206 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -615,7 +615,7 @@ repos: ^providers/google/src/airflow/providers/google/cloud/operators/cloud_build\.py$| ^providers/google/src/airflow/providers/google/cloud/operators/dataproc\.py$| ^providers/google/src/airflow/providers/google/cloud/operators/mlengine\.py$| - ^providers/keycloak/src/airflow/providers/keycloak/auth_manager/cli/definition.py| + ^providers/keycloak/src/airflow/providers/keycloak/cli/definition.py| ^providers/microsoft/azure/docs/connections/azure_cosmos\.rst$| ^providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/cosmos\.py$| ^providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm\.py$| diff --git a/airflow-core/docs/cli-and-env-variables-ref.rst b/airflow-core/docs/cli-and-env-variables-ref.rst index 9fe8bdb80d07a..309d17e43561c 100644 --- a/airflow-core/docs/cli-and-env-variables-ref.rst +++ b/airflow-core/docs/cli-and-env-variables-ref.rst @@ -37,6 +37,11 @@ development and testing. Providers that implement executors might contribute additional commands to the CLI. Here are the commands contributed by the community providers: + +.. important:: + Starting in Airflow ``3.2.0``, provider-level CLI commands are available to manage core extensions such as auth managers and executors. Implementing provider-level CLI commands can reduce CLI startup time by avoiding heavy imports when they are not required. + See :doc:`provider-level CLI ` for implementation guidance. + * Celery Executor and related CLI commands: :doc:`apache-airflow-providers-celery:cli-ref` * Kubernetes Executor and related CLI commands: :doc:`apache-airflow-providers-cncf-kubernetes:cli-ref` * Edge Executor and related CLI commands: :doc:`apache-airflow-providers-edge3:cli-ref` diff --git a/airflow-core/docs/core-concepts/auth-manager/index.rst b/airflow-core/docs/core-concepts/auth-manager/index.rst index 511aa308beb1f..71db38330303a 100644 --- a/airflow-core/docs/core-concepts/auth-manager/index.rst +++ b/airflow-core/docs/core-concepts/auth-manager/index.rst @@ -209,6 +209,10 @@ The following methods aren't required to override to have a functional Airflow a CLI ^^^ +.. important:: + Starting in Airflow ``3.2.0``, provider-level CLI commands are available to manage core extensions such as auth managers and executors. Implementing provider-level CLI commands can reduce CLI startup time by avoiding heavy imports when they are not required. + See :doc:`provider-level CLI ` for implementation guidance. + Auth managers may vend CLI commands which will be included in the ``airflow`` command line tool by implementing the ``get_cli_commands`` method. The commands can be used to setup required resources. Commands are only vended for the currently configured auth manager. A pseudo-code example of implementing CLI command vending from an auth manager can be seen below: .. code-block:: python diff --git a/airflow-core/docs/core-concepts/executor/index.rst b/airflow-core/docs/core-concepts/executor/index.rst index d27b306c3dc97..28699751360f5 100644 --- a/airflow-core/docs/core-concepts/executor/index.rst +++ b/airflow-core/docs/core-concepts/executor/index.rst @@ -317,6 +317,10 @@ The ``BaseExecutor`` class interface contains a set of attributes that Airflow c CLI ^^^ +.. 
important:: + Starting in Airflow ``3.2.0``, provider-level CLI commands are available to manage core extensions such as auth managers and executors. Implementing provider-level CLI commands can reduce CLI startup time by avoiding heavy imports when they are not required. + See :doc:`provider-level CLI ` for implementation guidance. + Executors may vend CLI commands which will be included in the ``airflow`` command line tool by implementing the ``get_cli_commands`` method. Executors such as ``CeleryExecutor`` and ``KubernetesExecutor`` for example, make use of this mechanism. The commands can be used to setup required workers, initialize environment or set other configuration. Commands are only vended for the currently configured executor. A pseudo-code example of implementing CLI command vending from an executor can be seen below: .. code-block:: python diff --git a/airflow-core/src/airflow/cli/cli_parser.py b/airflow-core/src/airflow/cli/cli_parser.py index 3a59218648c44..0cac7331eddb3 100644 --- a/airflow-core/src/airflow/cli/cli_parser.py +++ b/airflow-core/src/airflow/cli/cli_parser.py @@ -26,7 +26,7 @@ import argparse import logging -import sys +import os from argparse import Action from collections import Counter from collections.abc import Iterable @@ -36,7 +36,7 @@ import lazy_object_proxy from rich_argparse import RawTextRichHelpFormatter, RichHelpFormatter -from airflow.api_fastapi.app import get_auth_manager_cls +from airflow._shared.module_loading import import_string from airflow.cli.cli_config import ( DAG_CLI_DICT, ActionCommand, @@ -46,7 +46,7 @@ ) from airflow.cli.utils import CliConflictError from airflow.exceptions import AirflowException -from airflow.executors.executor_loader import ExecutorLoader +from airflow.providers_manager import ProvidersManager from airflow.utils.helpers import partition if TYPE_CHECKING: @@ -59,32 +59,117 @@ log = logging.getLogger(__name__) +# AIRFLOW_PACKAGE_NAME is set when generating docs and we don't want to load provider commands when generating airflow-core CLI docs +if not os.environ.get("AIRFLOW_PACKAGE_NAME", None): + providers_manager = ProvidersManager() + # Load CLI commands from providers + try: + for cli_function in providers_manager.cli_command_functions: + try: + airflow_commands.extend(cli_function()) + except Exception: + log.exception("Failed to load CLI commands from provider function: %s", cli_function.__name__) + log.error("Ensure all dependencies are met and try again.") + # Do not re-raise the exception since we want the CLI to still function for + # other commands. + except Exception as e: + log.warning("Failed to load CLI commands from providers: %s", e) + # do not re-raise for the same reason as above + + WARNING_TEMPLATE = """ +Please define the 'cli' section in the 'get_provider_info' for custom {component} to avoid this warning. +For community providers, please update to the version that supports the 'cli' section. 
+For more details, see https://airflow.apache.org/docs/apache-airflow-providers/core-extensions/cli-commands.html + +Providers with {component} missing 'cli' section in 'get_provider_info': {not_defined_cli_dict} + """ -for executor_name in ExecutorLoader.get_executor_names(validate_teams=False): + # compat loading for older providers that define get_cli_commands methods on Executors try: - executor, _ = ExecutorLoader.import_executor_cls(executor_name) - airflow_commands.extend(executor.get_cli_commands()) - except Exception: - log.exception("Failed to load CLI commands from executor: %s", executor_name) - log.error( - "Ensure all dependencies are met and try again. If using a Celery based executor install " - "a 3.3.0+ version of the Celery provider. If using a Kubernetes executor, install a " - "7.4.0+ version of the CNCF provider" + # if there is any executor_provider not in cli_provider, we have to do compat loading + # we use without check to avoid actual loading in this check + executors_not_defined_cli = { + executor_name: executor_provider + for executor_name, executor_provider in providers_manager.executor_without_check + if executor_provider not in providers_manager.cli_command_providers + } + if executors_not_defined_cli: + log.warning( + WARNING_TEMPLATE.format( + component="executors", not_defined_cli_dict=str(executors_not_defined_cli) + ) + ) + from airflow.executors.executor_loader import ExecutorLoader + + for executor_name in ExecutorLoader.get_executor_names(validate_teams=False): + # Skip if the executor already has CLI commands defined via the 'cli' section in provider.yaml + if executor_name.module_path not in executors_not_defined_cli: + log.debug( + "Skipping loading for '%s' as it is defined in 'cli' section.", + executor_name.module_path, + ) + continue + + try: + executor, _ = ExecutorLoader.import_executor_cls(executor_name) + airflow_commands.extend(executor.get_cli_commands()) + except Exception: + log.exception("Failed to load CLI commands from executor: %s", executor_name) + log.error( + "Ensure all dependencies are met and try again. If using a Celery based executor install " + "a 3.3.0+ version of the Celery provider. If using a Kubernetes executor, install a " + "7.4.0+ version of the CNCF provider" + ) + # Do not re-raise the exception since we want the CLI to still function for + # other commands. + + except Exception as e: + log.warning( + "Failed to load CLI commands from executors that didn't define `get_cli_commands` in `.cli.definition`: %s", + e, ) - # Do not re-raise the exception since we want the CLI to still function for - # other commands. 
-try: - auth_mgr = get_auth_manager_cls() - airflow_commands.extend(auth_mgr.get_cli_commands()) -except Exception as e: - log.warning("cannot load CLI commands from auth manager: %s", e) - log.warning("Auth manager is not configured and api-server will not be able to start.") - # do not re-raise for the same reason as above - if len(sys.argv) > 1 and sys.argv[1] == "api-server": - log.exception(e) - sys.exit(1) + # compat loading for older providers that define get_cli_commands methods on AuthManagers + try: + # if there is any auth_manager not in cli_provider, we have to do compat loading + # we use the without-check variant to avoid actually importing classes during this check + auth_managers_not_defined_cli = { + auth_manager_name: auth_manager_provider + for auth_manager_name, auth_manager_provider in providers_manager.auth_manager_without_check + if auth_manager_provider not in providers_manager.cli_command_providers + } + if auth_managers_not_defined_cli: + log.warning( + WARNING_TEMPLATE.format( + component="auth manager", not_defined_cli_dict=str(auth_managers_not_defined_cli) + ) + ) + from airflow.configuration import conf + from airflow.exceptions import AirflowConfigException + + auth_manager_cls_path = conf.get(section="core", key="auth_manager") + + if not auth_manager_cls_path: + raise AirflowConfigException( + "No auth manager defined in the config. Please specify one using section/key [core/auth_manager]." + ) + + if auth_manager_cls_path in auth_managers_not_defined_cli: + try: + auth_manager_cls = import_string(auth_manager_cls_path) + auth_manager = auth_manager_cls() + airflow_commands.extend(auth_manager.get_cli_commands()) + except Exception: + log.exception("Failed to load CLI commands from auth manager: %s", auth_manager_cls_path) + log.error("Ensure all dependencies are met and try again.") + # Do not re-raise the exception since we want the CLI to still function for + # other commands. + except Exception as e: + log.warning( + "Failed to load CLI commands from auth managers that didn't define `get_cli_commands` in `.cli.definition`: %s", + e, + ) ALL_COMMANDS_DICT: dict[str, CLICommand] = {sp.name: sp for sp in airflow_commands} @@ -94,7 +179,7 @@ dup = {k for k, v in Counter([c.name for c in airflow_commands]).items() if v > 1} raise CliConflictError( f"The following CLI {len(dup)} command(s) are defined more than once: {sorted(dup)}\n" - f"This can be due to an Executor or Auth Manager redefining core airflow CLI commands." + f"This can be due to a provider redefining core airflow CLI commands." 
)
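The compat paths above are only exercised for providers that have not opted in; the primary path simply calls the functions registered under a provider's ``cli`` section. A minimal sketch of such a function is shown below. The ``acme`` provider, its module path, and the command names are hypothetical, but the ``ActionCommand``/``GroupCommand``/``lazy_load_command`` APIs are the ones used throughout this change:

.. code-block:: python

    # Hypothetical module: airflow/providers/acme/cli/definition.py
    from __future__ import annotations

    from airflow.cli.cli_config import ActionCommand, CLICommand, GroupCommand, lazy_load_command


    def get_acme_cli_commands() -> list[CLICommand]:
        # Keep this module import-light: the command implementation is only
        # imported when the command actually runs, not at CLI parse time.
        return [
            GroupCommand(
                name="acme",
                help="Commands for the hypothetical Acme provider",
                subcommands=[
                    ActionCommand(
                        name="sync",
                        help="Run a one-off Acme sync",
                        func=lazy_load_command("airflow.providers.acme.cli.commands.sync"),
                        args=(),
                    ),
                ],
            ),
        ]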
diff --git a/airflow-core/src/airflow/provider.yaml.schema.json b/airflow-core/src/airflow/provider.yaml.schema.json index c35e0d9de25e7..79f071cd74888 100644 --- a/airflow-core/src/airflow/provider.yaml.schema.json +++ b/airflow-core/src/airflow/provider.yaml.schema.json @@ -419,6 +419,13 @@ "type": "string" } }, + "cli": { + "type": "array", + "description": "CLI command functions exposed by the provider", + "items": { + "type": "string" + } + }, "config": { "type": "object", "additionalProperties": { diff --git a/airflow-core/src/airflow/provider_info.schema.json b/airflow-core/src/airflow/provider_info.schema.json index 3ca9756dfb2f6..3e408b07536f1 100644 --- a/airflow-core/src/airflow/provider_info.schema.json +++ b/airflow-core/src/airflow/provider_info.schema.json @@ -355,6 +355,13 @@ "type": "string" } }, + "cli": { + "type": "array", + "description": "CLI command functions exposed by the provider", + "items": { + "type": "string" + } + }, "config": { "type": "object", "additionalProperties": { diff --git a/airflow-core/src/airflow/providers_manager.py b/airflow-core/src/airflow/providers_manager.py index 227e063d5df22..074a47c58111a 100644 --- a/airflow-core/src/airflow/providers_manager.py +++ b/airflow-core/src/airflow/providers_manager.py @@ -31,7 +31,7 @@ from functools import wraps from importlib.resources import files as resource_files from time import perf_counter -from typing import TYPE_CHECKING, Any, NamedTuple, ParamSpec, TypeVar +from typing import TYPE_CHECKING, Any, NamedTuple, ParamSpec, TypeVar, cast from packaging.utils import canonicalize_name @@ -40,6 +40,9 @@ from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.singleton import Singleton +if TYPE_CHECKING: + from airflow.cli.cli_config import CLICommand + log = logging.getLogger(__name__) @@ -405,11 +408,15 @@ def __init__(self): self._connection_form_widgets: dict[str, ConnectionFormWidgetInfo] = {} # Customizations for javascript fields are kept here self._field_behaviours: dict[str, dict] = {} + self._cli_command_functions_set: set[Callable[[], list[CLICommand]]] = set() + self._cli_command_provider_name_set: set[str] = set() self._extra_link_class_name_set: set[str] = set() self._logging_class_name_set: set[str] = set() self._auth_manager_class_name_set: set[str] = set() + self._auth_manager_without_check_set: set[tuple[str, str]] = set() self._secrets_backend_class_name_set: set[str] = set() self._executor_class_name_set: set[str] = set() + self._executor_without_check_set: set[tuple[str, str]] = set() self._queue_class_name_set: set[str] = set() self._provider_configs: dict[str, dict[str, Any]] = {} self._trigger_info_set: set[TriggerInfo] = set() @@ -525,7 +532,13 @@ def initialize_providers_secrets_backends(self): def initialize_providers_executors(self): """Lazy initialization of providers executors information.""" self.initialize_providers_list() - self._discover_executors() + self._discover_executors(check=True) + + @provider_info_cache("executors_without_check") + def initialize_providers_executors_without_check(self): + """Lazy initialization of providers executors information (without correctness check).""" + self.initialize_providers_list() + self._discover_executors(check=False) @provider_info_cache("queues") def initialize_providers_queues(self): @@ -541,9 +554,15 @@ def initialize_providers_notifications(self): @provider_info_cache("auth_managers") def initialize_providers_auth_managers(self): - """Lazy initialization of providers notifications information.""" + """Lazy initialization of 
providers auth manager information.""" + self.initialize_providers_list() + self._discover_auth_managers(check=True) + + @provider_info_cache("auth_managers_without_check") + def initialize_providers_auth_managers_without_check(self): + """Lazy initialization of providers auth manager information (without correctness check).""" self.initialize_providers_list() - self._discover_auth_managers() + self._discover_auth_managers(check=False) @provider_info_cache("config") def initialize_providers_configuration(self): @@ -573,6 +592,12 @@ def initialize_providers_plugins(self): self.initialize_providers_list() self._discover_plugins() + @provider_info_cache("cli_command") + def initialize_providers_cli_command(self): + """Lazy initialization of providers CLI commands.""" + self.initialize_providers_list() + self._discover_cli_command() + def _discover_all_providers_from_packages(self) -> None: """ Discover all providers by scanning packages installed. @@ -1060,14 +1085,28 @@ def _add_customized_fields(self, package_name: str, hook_class: type, customized e, ) - def _discover_auth_managers(self) -> None: + def _discover_auth_managers(self, *, check: bool) -> None: """Retrieve all auth managers defined in the providers.""" for provider_package, provider in self._provider_dict.items(): if provider.data.get("auth-managers"): for auth_manager_class_name in provider.data["auth-managers"]: - if _correctness_check(provider_package, auth_manager_class_name, provider): + if not check: + self._auth_manager_without_check_set.add((auth_manager_class_name, provider_package)) + elif _correctness_check(provider_package, auth_manager_class_name, provider): self._auth_manager_class_name_set.add(auth_manager_class_name) + def _discover_cli_command(self) -> None: + """Retrieve all CLI command functions defined in the providers.""" + for provider_package, provider in self._provider_dict.items(): + if provider.data.get("cli"): + for cli_command_function_name in provider.data["cli"]: + # _correctness_check returns the imported function when it is found and correct. + # We store the function itself rather than its name so cli_parser does not + # have to import it again later, which speeds up CLI loading. + if cli_func := _correctness_check(provider_package, cli_command_function_name, provider): + cli_func = cast("Callable[[], list[CLICommand]]", cli_func) + self._cli_command_functions_set.add(cli_func) + self._cli_command_provider_name_set.add(provider_package) + def _discover_notifications(self) -> None: """Retrieve all notifications defined in the providers.""" for provider_package, provider in self._provider_dict.items(): @@ -1100,13 +1139,15 @@ def _discover_secrets_backends(self) -> None: if _correctness_check(provider_package, secrets_backends_class_name, provider): self._secrets_backend_class_name_set.add(secrets_backends_class_name) - def _discover_executors(self) -> None: + def _discover_executors(self, *, check: bool) -> None: """Retrieve all executors defined in the providers.""" for provider_package, provider in self._provider_dict.items(): if provider.data.get("executors"): - for executors_class_name in provider.data["executors"]: - if _correctness_check(provider_package, executors_class_name, provider): - self._executor_class_name_set.add(executors_class_name) + for executors_class_path in provider.data["executors"]: + if not check: + self._executor_without_check_set.add((executors_class_path, provider_package)) + elif _correctness_check(provider_package, executors_class_path, provider): + self._executor_class_name_set.add(executors_class_path) def 
_discover_queues(self) -> None: """Retrieve all queues defined in the providers.""" @@ -1158,6 +1199,24 @@ def auth_managers(self) -> list[str]: self.initialize_providers_auth_managers() return sorted(self._auth_manager_class_name_set) + @property + def auth_manager_without_check(self) -> set[tuple[str, str]]: + """Returns the set of (auth manager class name, provider package name) tuples, without correctness check.""" + self.initialize_providers_auth_managers_without_check() + return self._auth_manager_without_check_set + + @property + def cli_command_functions(self) -> set[Callable[[], list[CLICommand]]]: + """Returns the set of CLI command functions defined by providers.""" + self.initialize_providers_cli_command() + return self._cli_command_functions_set + + @property + def cli_command_providers(self) -> set[str]: + """Returns the set of provider package names that provide CLI commands.""" + self.initialize_providers_cli_command() + return self._cli_command_provider_name_set + @property def notification(self) -> list[NotificationInfo]: """Returns information about available providers notifications class.""" @@ -1246,6 +1305,12 @@ def executor_class_names(self) -> list[str]: self.initialize_providers_executors() return sorted(self._executor_class_name_set) + @property + def executor_without_check(self) -> set[tuple[str, str]]: + """Returns the set of (executor class name, provider package name) tuples, without correctness check.""" + self.initialize_providers_executors_without_check() + return self._executor_without_check_set + @property def queue_class_names(self) -> list[str]: self.initialize_providers_queues() @@ -1295,13 +1360,17 @@ def _cleanup(self): self._extra_link_class_name_set.clear() self._logging_class_name_set.clear() self._auth_manager_class_name_set.clear() + self._auth_manager_without_check_set.clear() self._secrets_backend_class_name_set.clear() self._executor_class_name_set.clear() + self._executor_without_check_set.clear() self._queue_class_name_set.clear() self._provider_configs.clear() self._trigger_info_set.clear() self._notification_info_set.clear() self._plugins_set.clear() + self._cli_command_functions_set.clear() + self._cli_command_provider_name_set.clear() self._initialized = False self._initialization_stack_trace = None
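Because ``_discover_cli_command`` stores the imported callables themselves, consumers get ready-to-call functions rather than dotted paths. A usage sketch of the new properties (assuming at least one installed provider defines a ``cli`` section):

.. code-block:: python

    from airflow.providers_manager import ProvidersManager

    providers_manager = ProvidersManager()

    # Each entry is an already-imported function returning list[CLICommand],
    # so no further import_string round-trip is needed at CLI build time.
    for build_commands in providers_manager.cli_command_functions:
        for command in build_commands():
            print(command.name)

    # Provider packages that opted in via the 'cli' section:
    print(sorted(providers_manager.cli_command_providers))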
diff --git a/airflow-core/tests/unit/always/test_providers_manager.py b/airflow-core/tests/unit/always/test_providers_manager.py index 01c1ded00cc59..e0d686a565bb7 100644 --- a/airflow-core/tests/unit/always/test_providers_manager.py +++ b/airflow-core/tests/unit/always/test_providers_manager.py @@ -21,6 +21,8 @@ import logging import re import sys +from collections.abc import Callable +from typing import TYPE_CHECKING PY313 = sys.version_info >= (3, 13) import warnings @@ -41,6 +43,11 @@ from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker, skip_if_not_on_main from tests_common.test_utils.paths import AIRFLOW_ROOT_PATH +if TYPE_CHECKING: + from unittest.mock import MagicMock + + from airflow.cli.cli_config import CLICommand + def test_cleanup_providers_manager(cleanup_providers_manager): """Check the cleanup provider manager functionality.""" @@ -328,6 +335,40 @@ def test_auth_managers(self): auth_manager_class_names = list(provider_manager.auth_managers) assert len(auth_manager_class_names) > 0 + def test_cli(self): + provider_manager = ProvidersManager() + + # assert cli_command_functions is a set of Callable[[], list[CLICommand]] + assert isinstance(provider_manager.cli_command_functions, set) + assert all(callable(func) for func in provider_manager.cli_command_functions) + # assert cli_command_providers is a set of str + assert isinstance(provider_manager.cli_command_providers, set) + assert all(isinstance(provider, str) for provider in provider_manager.cli_command_providers) + + sorted_cli_command_functions: list[Callable[[], list[CLICommand]]] = sorted( + provider_manager.cli_command_functions, key=lambda x: x.__module__ + ) + sorted_cli_command_providers: list[str] = sorted(provider_manager.cli_command_providers) + + expected_functions_modules = [ + "airflow.providers.amazon.aws.cli.definition", + "airflow.providers.celery.cli.definition", + "airflow.providers.cncf.kubernetes.cli.definition", + "airflow.providers.edge3.cli.definition", + "airflow.providers.fab.cli.definition", + "airflow.providers.keycloak.cli.definition", + ] + expected_providers = [ + "apache-airflow-providers-amazon", + "apache-airflow-providers-celery", + "apache-airflow-providers-cncf-kubernetes", + "apache-airflow-providers-edge3", + "apache-airflow-providers-fab", + "apache-airflow-providers-keycloak", + ] + assert [func.__module__ for func in sorted_cli_command_functions] == expected_functions_modules + assert sorted_cli_command_providers == expected_providers + def test_dialects(self): provider_manager = ProvidersManager() dialect_class_names = list(provider_manager.dialects) @@ -363,6 +404,45 @@ def test_optional_feature_debug(self, mock_importlib_import_string): ] +class TestWithoutCheckProviderManager: + @patch("airflow.providers_manager.import_string") + @patch("airflow.providers_manager._correctness_check") + @patch("airflow.providers_manager.ProvidersManager._discover_auth_managers") + def test_auth_manager_without_check_property_should_not_call_import_string( + self, + mock_discover_auth_managers: MagicMock, + mock_correctness_check: MagicMock, + mock_importlib_import_string: MagicMock, + ): + providers_manager = ProvidersManager() + result = providers_manager.auth_manager_without_check + + mock_discover_auth_managers.assert_called_once_with(check=False) + mock_importlib_import_string.assert_not_called() + mock_correctness_check.assert_not_called() + + assert providers_manager._auth_manager_without_check_set == result + + @patch("airflow.providers_manager.import_string") + @patch("airflow.providers_manager._correctness_check") + @patch("airflow.providers_manager.ProvidersManager._discover_executors") + def test_executors_without_check_property_should_not_call_import_string( + self, + mock_discover_executors: MagicMock, + mock_correctness_check: MagicMock, + mock_importlib_import_string: MagicMock, + ): + providers_manager = ProvidersManager() + result = providers_manager.executor_without_check + + mock_discover_executors.assert_called_once_with(check=False) + mock_importlib_import_string.assert_not_called() + mock_correctness_check.assert_not_called() + + assert providers_manager._executor_without_check_set == result + + @pytest.mark.parametrize( ("value", "expected_outputs"), [ diff --git a/airflow-core/tests/unit/cli/test_cli_parser.py b/airflow-core/tests/unit/cli/test_cli_parser.py index fed8b36f67bc6..56d2c9e6a9558 100644 --- a/airflow-core/tests/unit/cli/test_cli_parser.py +++ b/airflow-core/tests/unit/cli/test_cli_parser.py @@ -20,12 +20,14 @@ import argparse import contextlib +import logging import os import re import subprocess import sys import timeit from collections import Counter +from collections.abc import Callable from importlib import reload from io import StringIO 
from pathlib import Path @@ -39,10 +41,6 @@ from airflow.cli.utils import CliConflictError from airflow.configuration import AIRFLOW_HOME from airflow.executors import executor_loader -from airflow.executors.executor_utils import ExecutorName -from airflow.executors.local_executor import LocalExecutor -from airflow.providers.amazon.aws.executors.ecs.ecs_executor import AwsEcsExecutor -from airflow.providers.celery.executors.celery_executor import CeleryExecutor from tests_common.test_utils.config import conf_vars @@ -143,133 +141,297 @@ def test_subcommand_arg_flag_conflict(self): f"short option flags {conflict_short_option}" ) - @pytest.mark.db_test - @patch.object(LocalExecutor, "get_cli_commands") - def test_dynamic_conflict_detection(self, cli_commands_mock: MagicMock): - core_commands.append( + @staticmethod + def mock_duplicate_command(): + return [ ActionCommand( name="test_command", help="does nothing", func=lambda: None, args=[], - ) - ) - cli_commands_mock.return_value = [ + ), ActionCommand( name="test_command", - help="just a command that'll conflict with one defined in core", + help="just a command that'll conflict with the other one", func=lambda: None, args=[], - ) + ), ] - reload(executor_loader) - with pytest.raises(CliConflictError, match="test_command"): - # force re-evaluation of cli commands (done in top level code) - reload(cli_parser) - @patch.object(CeleryExecutor, "get_cli_commands") - @patch.object(AwsEcsExecutor, "get_cli_commands") - def test_hybrid_executor_get_cli_commands( - self, ecs_executor_cli_commands_mock, celery_executor_cli_commands_mock - ): - """Test that if multiple executors are configured, then every executor loads its commands.""" - ecs_executor_command = ActionCommand( - name="ecs_command", - help="test command for ecs executor", - func=lambda: None, - args=[], - ) - ecs_executor_cli_commands_mock.return_value = [ecs_executor_command] + @patch( + "airflow.providers_manager.ProvidersManager.cli_command_functions", + new_callable=mock.PropertyMock, + ) + def test_dynamic_conflict_detection(self, mock_cli_command_functions: MagicMock): + mock_cli_command_functions.return_value = [self.mock_duplicate_command] - celery_executor_command = ActionCommand( - name="celery_command", - help="test command for celery executor", + test_command = ActionCommand( + name="test_command", + help="does nothing", func=lambda: None, args=[], ) - celery_executor_cli_commands_mock.return_value = [celery_executor_command] - reload(executor_loader) - executor_loader.ExecutorLoader.get_executor_names = mock.Mock( - return_value=[ - ExecutorName("airflow.providers.celery.executors.celery_executor.CeleryExecutor"), - ExecutorName("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"), - ] - ) + core_commands.append(test_command) - reload(cli_parser) - commands = [command.name for command in cli_parser.airflow_commands] - assert celery_executor_command.name in commands - assert ecs_executor_command.name in commands + with pytest.raises(CliConflictError, match="test_command"): + # force re-evaluation of cli commands (done in top level code) + reload(cli_parser) - @patch.object(CeleryExecutor, "get_cli_commands") - @patch.object(AwsEcsExecutor, "get_cli_commands") - def test_hybrid_executor_get_cli_commands_with_error( - self, ecs_executor_cli_commands_mock, celery_executor_cli_commands_mock, caplog - ): - """ - Test that if multiple executors are configured, then every executor loads its commands. 
- If the executor fails to load its commands, the CLI should log the error, and continue loading - """ - caplog.set_level("ERROR") - ecs_executor_command = ActionCommand( - name="ecs_command", - help="test command for ecs executor", - func=lambda: None, - args=[], - ) - ecs_executor_cli_commands_mock.side_effect = Exception() + @pytest.mark.parametrize( + "module_pattern", + ["airflow.auth.managers", "airflow.executors.executor_loader"], + ) + def test_should_not_import_in_cli_parser(self, module_pattern: str): + """Test that cli_parser does not import auth_managers or executor_loader at import time.""" + # Remove the module from sys.modules if present to force a fresh import + import sys - celery_executor_command = ActionCommand( - name="celery_command", - help="test command for celery executor", - func=lambda: None, - args=[], - ) - celery_executor_cli_commands_mock.return_value = [celery_executor_command] - reload(executor_loader) - executor_loader.ExecutorLoader.get_executor_names = mock.Mock( - return_value=[ - ExecutorName("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"), - ExecutorName("airflow.providers.celery.executors.celery_executor.CeleryExecutor"), - ] - ) + modules_to_remove = [mod for mod in sys.modules.keys() if module_pattern in mod] + removed_modules = {} + for mod in modules_to_remove: + removed_modules[mod] = sys.modules.pop(mod) - reload(cli_parser) - commands = [command.name for command in cli_parser.airflow_commands] - assert celery_executor_command.name in commands - assert ecs_executor_command.name not in commands - assert ( - "Failed to load CLI commands from executor: ::airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor" - in caplog.messages[0] - ) + try: + reload(cli_parser) + # Check that the module pattern is not in sys.modules after reload + loaded_modules = list(sys.modules.keys()) + matching_modules = [mod for mod in loaded_modules if module_pattern in mod] + assert not matching_modules, ( + f"Module pattern '{module_pattern}' found in sys.modules: {matching_modules}" + ) + finally: + # Restore removed modules + sys.modules.update(removed_modules) - @patch.object(AwsEcsExecutor, "get_cli_commands") - def test_cli_parser_fail_to_load_executor(self, ecs_executor_cli_commands_mock, caplog): - caplog.set_level("ERROR") + @pytest.mark.parametrize( + ( + "cli_command_functions", + "cli_command_providers", + "executor_without_check", + "expected_loaded_executors", + ), + [ + pytest.param( + [], + set(), + { + ("path.to.KubernetesExecutor", "apache-airflow-providers-cncf-kubernetes"), + }, + ["path.to.KubernetesExecutor"], + id="empty cli section should load all the executors by ExecutorLoader", + ), + pytest.param( + [lambda: [ActionCommand(name="celery", help="", func=lambda: None, args=[])]], + {"apache-airflow-providers-celery"}, + { + ("path.to.CeleryExecutor", "apache-airflow-providers-celery"), + ("path.to.KubernetesExecutor", "apache-airflow-providers-cncf-kubernetes"), + }, + ["path.to.KubernetesExecutor"], + id="only partial executor define cli section in provider info, should load the rest by ExecutorLoader", + ), + pytest.param( + [ + lambda: [ActionCommand(name="celery", help="", func=lambda: None, args=[])], + lambda: [ActionCommand(name="kubernetes", help="", func=lambda: None, args=[])], + ], + {"apache-airflow-providers-celery", "apache-airflow-providers-cncf-kubernetes"}, + { + ("path.to.CeleryExecutor", "apache-airflow-providers-celery"), + ("path.to.KubernetesExecutor", 
"apache-airflow-providers-cncf-kubernetes"), + }, + [], + id="all executors define cli section in provider info, should not load any by ExecutorLoader", + ), + ], + ) + @patch("airflow.executors.executor_loader.ExecutorLoader.import_executor_cls") + @patch( + "airflow.executors.executor_loader.ExecutorLoader.get_executor_names", + ) + @patch( + "airflow.providers_manager.ProvidersManager.executor_without_check", + new_callable=mock.PropertyMock, + ) + @patch( + "airflow.providers_manager.ProvidersManager.cli_command_providers", + new_callable=mock.PropertyMock, + ) + @patch( + "airflow.providers_manager.ProvidersManager.cli_command_functions", + new_callable=mock.PropertyMock, + ) + def test_compat_cli_loading_for_executors_commands( + self, + mock_cli_command_functions: MagicMock, + mock_cli_command_providers: MagicMock, + mock_executor_without_check: MagicMock, + mock_get_executor_names: MagicMock, + mock_import_executor_cls: MagicMock, + cli_command_functions: list[Callable[[], list[ActionCommand | cli_parser.GroupCommand]]], + cli_command_providers: set[str], + executor_without_check: set[tuple[str, str]], + expected_loaded_executors: list[str], + caplog, + ): + # Create mock ExecutorName objects + mock_executor_names = [ + MagicMock(name=executor_name.split(".")[-1], module_path=executor_name) + for executor_name, _ in executor_without_check + ] - ecs_executor_command = ActionCommand( - name="ecs_command", - help="test command for ecs executor", - func=lambda: None, - args=[], - ) - ecs_executor_cli_commands_mock.return_value = [ecs_executor_command] + # Create mock executor classes that return empty command lists + mock_executor_instance = MagicMock() + mock_executor_instance.get_cli_commands.return_value = [] + mock_import_executor_cls.return_value = (mock_executor_instance, None) + + # mock + mock_cli_command_functions.return_value = cli_command_functions + mock_cli_command_providers.return_value = cli_command_providers + mock_executor_without_check.return_value = executor_without_check + mock_get_executor_names.return_value = mock_executor_names + + # act + with caplog.at_level(logging.WARNING, logger="airflow.cli.cli_parser"): + reload(cli_parser) - reload(executor_loader) - executor_loader.ExecutorLoader.get_executor_names = mock.Mock( - return_value=[ - ExecutorName("airflow.providers.incorrect.executor.Executor"), - ExecutorName("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"), + # assert + expected_warning = "Please define the 'cli' section in the 'get_provider_info' for custom executors to avoid this warning." 
+ if expected_loaded_executors: + assert expected_warning in caplog.text + for executor_path in expected_loaded_executors: + assert executor_path in caplog.text + else: + assert expected_warning not in caplog.text + + # Verify import_executor_cls was called with correct ExecutorName objects + if expected_loaded_executors: + expected_calls = [ + mock.call(executor_name) + for executor_name in mock_executor_names + if executor_name.module_path in expected_loaded_executors ] - ) + mock_import_executor_cls.assert_has_calls(expected_calls, any_order=True) + else: + mock_import_executor_cls.assert_not_called() - reload(cli_parser) - commands = [command.name for command in cli_parser.airflow_commands] - assert ecs_executor_command.name in commands - assert ( - "Failed to load CLI commands from executor: ::airflow.providers.incorrect.executor.Executor" - in caplog.messages[0] - ) + @pytest.mark.parametrize( + ( + "cli_command_functions", + "cli_command_providers", + "auth_manager_without_check", + "auth_manager_cls_path", + "expected_loaded", + ), + [ + pytest.param( + [], + set(), + { + ( + "airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager", + "apache-airflow-providers-fab", + ), + }, + "airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager", + True, + id="empty cli section should load auth manager", + ), + pytest.param( + [lambda: [ActionCommand(name="fab", help="", func=lambda: None, args=[])]], + {"apache-airflow-providers-fab"}, + { + ( + "airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager", + "apache-airflow-providers-fab", + ), + }, + "airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager", + False, + id="auth manager with cli section should not load by import_string", + ), + pytest.param( + [], + set(), + { + ( + "airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager", + "apache-airflow-providers-amazon", + ), + ( + "airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager", + "apache-airflow-providers-fab", + ), + }, + "airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager", + True, + id="only configured auth manager should be loaded", + ), + ], + ) + @patch( + "airflow.providers_manager.ProvidersManager.cli_command_functions", + new_callable=mock.PropertyMock, + ) + @patch( + "airflow.providers_manager.ProvidersManager.cli_command_providers", + new_callable=mock.PropertyMock, + ) + @patch( + "airflow.providers_manager.ProvidersManager.auth_manager_without_check", + new_callable=mock.PropertyMock, + ) + @patch("airflow.configuration.conf.get") + @patch("airflow._shared.module_loading.import_string") + def test_compat_cli_loading_for_auth_manager_commands( + self, + mock_import_string: MagicMock, + mock_conf_get: MagicMock, + mock_auth_manager_without_check: MagicMock, + mock_cli_command_providers: MagicMock, + mock_cli_command_functions: MagicMock, + cli_command_functions: list[Callable[[], list[ActionCommand | cli_parser.GroupCommand]]], + cli_command_providers: set[str], + auth_manager_without_check: set[tuple[str, str]], + auth_manager_cls_path: str, + expected_loaded: bool, + caplog, + ): + # Create mock auth manager instance that returns empty command lists + mock_auth_manager_instance = MagicMock() + mock_auth_manager_instance.get_cli_commands.return_value = [] + mock_auth_manager_cls = MagicMock(return_value=mock_auth_manager_instance) + mock_import_string.return_value = mock_auth_manager_cls + + # Mock configuration + mock_conf_get.return_value = auth_manager_cls_path + + 
# mock providers manager + mock_cli_command_functions.return_value = cli_command_functions + mock_cli_command_providers.return_value = cli_command_providers + mock_auth_manager_without_check.return_value = auth_manager_without_check + + # act + with caplog.at_level(logging.WARNING, logger="airflow.cli.cli_parser"): + reload(cli_parser) + + # assert + expected_warning = "Please define the 'cli' section in the 'get_provider_info' for custom auth manager to avoid this warning." + if expected_loaded: + assert expected_warning in caplog.text + assert auth_manager_cls_path in caplog.text + mock_import_string.assert_called_once_with(auth_manager_cls_path) + mock_auth_manager_cls.assert_called_once() + mock_auth_manager_instance.get_cli_commands.assert_called_once() + else: + if auth_manager_cls_path in [path for path, _ in auth_manager_without_check]: + # Auth manager is in the without_check but also in cli_providers, so warning should appear + # but import_string should NOT be called + assert expected_warning not in caplog.text + mock_import_string.assert_not_called() + else: + # Auth manager is not in the without_check, no warning + mock_import_string.assert_not_called() def test_falsy_default_value(self): arg = cli_config.Arg(("--test",), default=0, type=int) @@ -368,27 +530,6 @@ def test_variables_import_help_message_consistency(self): f"Please update ARG_VAR_IMPORT help message in cli_config.py to include: {', '.join([f'.{fmt}' for fmt in sorted(missing_in_help)])}" ) - @pytest.mark.parametrize( - "command", - [ - "celery", - "kubernetes", - ], - ) - def test_executor_specific_commands_not_accessible(self, command): - with ( - contextlib.redirect_stderr(StringIO()) as stderr, - ): - reload(executor_loader) - reload(cli_parser) - parser = cli_parser.get_parser() - with pytest.raises(SystemExit): - parser.parse_args([command]) - stderr_val = stderr.getvalue() - assert ( - f"airflow command error: argument GROUP_OR_COMMAND: invalid choice: '{command}'" - ) in stderr_val - @pytest.mark.parametrize( ("executor", "expected_args"), [ @@ -493,24 +634,6 @@ def test_cli_run_time(self): # Average run time of Airflow CLI should at least be within 3.5s assert timing_result < threshold - def test_cli_parsing_does_not_initialize_providers_manager(self): - """ - Test that CLI parsing does not initialize providers manager. - - This test is here to make sure that we do not initialize providers manager - it is run as a - separate subprocess, to make sure we do not have providers manager initialized in the main - process from other tests. - """ - CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True) - CONFIG_FILE.touch(exist_ok=True) - result = subprocess.run( - [sys.executable, "-m", "airflow", "providers", "lazy-loaded"], - env={"PYTHONPATH": os.pathsep.join(sys.path)}, - check=False, - text=True, - ) - assert result.returncode == 0 - def test_airflow_config_contains_providers(self): """ Test that airflow config has providers included by default. 
@@ -558,13 +681,3 @@ def test_airflow_config_output_does_not_contain_providers_when_excluded(self): ) assert result.returncode == 0 assert "celery_config_options" not in result.stdout - - def test_cli_parser_skips_team_validation(self): - """Test that CLI parser calls get_executor_names with validate_teams=False to prevent database dependency during CLI loading.""" - with patch.object(executor_loader.ExecutorLoader, "get_executor_names") as mock_get_executor_names: - mock_get_executor_names.return_value = [] - # Force reload of cli_parser to trigger the executor loading - reload(cli_parser) - - # Verify get_executor_names was called with validate_teams=False - mock_get_executor_names.assert_called_with(validate_teams=False) diff --git a/devel-common/src/sphinx_exts/operators_and_hooks_ref.py b/devel-common/src/sphinx_exts/operators_and_hooks_ref.py index 489aebf70fdf4..d2c86f6de9827 100644 --- a/devel-common/src/sphinx_exts/operators_and_hooks_ref.py +++ b/devel-common/src/sphinx_exts/operators_and_hooks_ref.py @@ -532,6 +532,22 @@ def render_content( ) +class CliCommandsDirective(BaseJinjaReferenceDirective): + """Generate list of CLI commands""" + + def render_content( + self, *, tags: set[str] | None, header_separator: str = DEFAULT_HEADER_SEPARATOR + ) -> str: + tabular_data = [ + (provider["name"], provider["package-name"]) + for provider in load_package_data() + if provider.get("cli") is not None + ] + return _render_template( + "cli-commands.rst.jinja2", items=tabular_data, header_separator=header_separator + ) + + def setup(app): """Setup plugin""" app.add_directive("operators-hooks-ref", OperatorsHooksReferenceDirective) @@ -548,6 +564,7 @@ def setup(app): app.add_directive("airflow-deprecations", DeprecationsDirective) app.add_directive("airflow-dataset-schemes", AssetSchemeDirective) app.add_directive("airflow-auth-managers", AuthManagersDirective) + app.add_directive("airflow-cli-commands", CliCommandsDirective) return {"parallel_read_safe": True, "parallel_write_safe": True} diff --git a/devel-common/src/sphinx_exts/templates/cli-commands.rst.jinja2 b/devel-common/src/sphinx_exts/templates/cli-commands.rst.jinja2 new file mode 100644 index 0000000000000..5cd58c2284f27 --- /dev/null +++ b/devel-common/src/sphinx_exts/templates/cli-commands.rst.jinja2 @@ -0,0 +1,22 @@ +{# + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +#} + +{% for name, provider_package in items %} +* :doc:`CLI for {{ name }} ({{ provider_package }})<{{ provider_package }}:cli-ref>` +{% endfor %}
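For a provider entry such as Amazon, the template above renders one bullet per provider, roughly like this (illustrative output derived from the template and the provider metadata in this change):

.. code-block:: rst

    * :doc:`CLI for Amazon (apache-airflow-providers-amazon)<apache-airflow-providers-amazon:cli-ref>`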
diff --git a/providers-summary-docs/core-extensions/cli-commands.rst b/providers-summary-docs/core-extensions/cli-commands.rst new file mode 100644 index 0000000000000..e4712aad9d224 --- /dev/null +++ b/providers-summary-docs/core-extensions/cli-commands.rst @@ -0,0 +1,143 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Command Line Interface +---------------------- + +.. important:: + The Airflow Core version must be ``3.2.0`` or newer to be able to use provider-level CLI commands. + + +If your provider includes :doc:`/core-extensions/auth-managers` or :doc:`/core-extensions/executors`, you should also implement provider-level CLI commands +to improve Airflow CLI response speed and to avoid loading heavy dependencies when those commands are not needed. + +Even if your auth manager or executor does not implement the ``get_cli_commands`` interface (:meth:`airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_cli_commands` or :meth:`airflow.executors.base_executor.BaseExecutor.get_cli_commands`), you should still implement provider-level CLI commands that return an empty list, so that auth manager or executor code is not loaded for every CLI command. Doing so also resolves the following warnings: + +.. code-block:: console + + Please define the 'cli' section in the 'get_provider_info' for custom auth manager to avoid this warning. + For community providers, please update to the version that supports the 'cli' section. + For more details, see https://airflow.apache.org/docs/apache-airflow-providers/core-extensions/cli-commands.html + + Please define the 'cli' section in the 'get_provider_info' for custom executors to avoid this warning. + For community providers, please update to the version that supports the 'cli' section. + For more details, see https://airflow.apache.org/docs/apache-airflow-providers/core-extensions/cli-commands.html + + +Implementing provider-level CLI Commands +======================================== + +To implement provider-level CLI commands, follow these steps: + +1. Define all your CLI commands in the ``airflow.providers.<provider_name>.cli.definition`` module. You **should avoid importing heavy dependencies in this module** to keep Airflow CLI startup time low. Please use the ``airflow.cli.cli_config.lazy_load_command`` utility to lazily load the actual callable to run. + +.. code-block:: python + + from airflow.cli.cli_config import ( + ActionCommand, + Arg, + GroupCommand, + lazy_load_command, + ) + + + def get_my_cli_commands() -> list[GroupCommand]: + executor_sub_commands = [ + ActionCommand( + name="executor_subcommand_name", + help="Description of what this specific command does", + func=lazy_load_command("path.to.python.function.for.command"), + args=( + Arg( + "--my-arg", + help="Description of my arg", + action="store_true", + ), + ), + ), + ] + auth_manager_sub_commands = [ + ActionCommand( + name="auth_manager_subcommand_name", + help="Description of what this specific command does", + func=lazy_load_command("path.to.python.function.for.command"), + args=(), + ), + ] + custom_sub_commands = [ + ActionCommand( + name="custom_subcommand_name", + help="Description of what this specific command does", + func=lazy_load_command("path.to.python.function.for.command"), + args=(), + ), + ] + + return [ + GroupCommand( + name="my_cool_executor", + help="Description of what this group of commands does", + subcommands=executor_sub_commands, + ), + GroupCommand( + name="my_cool_auth_manager", + help="Description of what this group of commands does", + subcommands=auth_manager_sub_commands, + ), + GroupCommand( + name="my_cool_custom_commands", + help="Description of what this group of commands does", + subcommands=custom_sub_commands, + ), + ] 
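Step 1 insists on ``lazy_load_command`` because it defers importing the command implementation until the command is actually invoked. Conceptually it behaves like the sketch below (this illustrates the idea, not Airflow's exact implementation):

.. code-block:: python

    from importlib import import_module


    def lazy_load_command_sketch(import_path: str):
        """Return a callable that imports ``import_path`` only when invoked."""

        def command(*args, **kwargs):
            # The heavy module behind the command is imported here, at call
            # time, rather than when the CLI parser is built.
            module_path, _, attr_name = import_path.rpartition(".")
            func = getattr(import_module(module_path), attr_name)
            return func(*args, **kwargs)

        return command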
2. Update the ``cli`` section of your provider's ``provider.yaml`` file to point to the function that + returns the list of CLI commands. For example: + + +.. code-block:: yaml + + cli: + - airflow.providers.<provider_name>.cli.definition.get_my_cli_commands + +3. Update the ``get_provider_info.py`` file of your provider to include the CLI commands in the + returned dictionary. For example: + +.. code-block:: python + + def get_provider_info() -> dict[str, list[str]]: + return { + # ... + "cli": ["airflow.providers.<provider_name>.cli.definition.get_my_cli_commands"], + # ... + } + +You can read more about ``provider.yaml`` and ``get_provider_info.py`` in :doc:`/howto/create-custom-providers`. + +Community-Managed Provider CLI Commands +======================================= + +This is a summary of all Apache Airflow Community provided implementations of CLI commands +exposed via community-managed providers. + +.. note:: + For example, if you are using :doc:`KubernetesExecutor ` and you encounter the ``Please define the 'cli' section in the 'get_provider_info' for custom executors to avoid this warning.`` warning during CLI usage, ensure that you have updated to a version of the provider that includes the necessary CLI command definitions as described below. + +Those provided by the community-managed providers: + +.. airflow-cli-commands:: + :tags: None + :header-separator: "
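One way to sanity-check that a ``definition`` module stays import-light is to import it in a fresh interpreter and inspect ``sys.modules``. The sketch below uses the Amazon module path introduced in this change; treating ``boto3`` as the heavy dependency to watch for is an assumption made for illustration:

.. code-block:: python

    import subprocess
    import sys

    # Import the definition module in a clean interpreter and fail if a heavy
    # dependency (assumed here: boto3) is pulled in transitively.
    check = (
        "import sys; import airflow.providers.amazon.aws.cli.definition; "
        "assert not any(name.startswith('boto3') for name in sys.modules), 'heavy import leaked'"
    )
    subprocess.run([sys.executable, "-c", check], check=True)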
diff --git a/providers-summary-docs/howto/create-custom-providers.rst b/providers-summary-docs/howto/create-custom-providers.rst index 3a95adcbdf4b0..68d2da6acdd05 100644 --- a/providers-summary-docs/howto/create-custom-providers.rst +++ b/providers-summary-docs/howto/create-custom-providers.rst @@ -74,6 +74,9 @@ Exposing customized functionality to the Airflow's core: ``airflow/config_templates/config.yml.schema.json`` with configuration contributed by the providers See :doc:`apache-airflow:howto/set-config` for details about setting configuration. +* ``cli`` - this field should contain the list of all the functions that return CLI commands + to be included in the Airflow CLI. See :doc:`apache-airflow:cli-and-env-variables-ref` for a description of CLI commands. + * ``connection-types`` - this field should contain the list of all the connection types together with hook class names implementing those custom connection types (providing custom extra fields and custom field behaviour). This field is available as of Airflow 2.2.0 and it replaces deprecated diff --git a/providers-summary-docs/index.rst b/providers-summary-docs/index.rst index d6efb80f92af8..df417f75ab37b 100644 --- a/providers-summary-docs/index.rst +++ b/providers-summary-docs/index.rst @@ -64,6 +64,18 @@ Providers can have their own configuration options which allow you to configure You can see all community-managed providers with their own configuration in :doc:`/core-extensions/configurations` +Command Line Interface +'''''''''''''''''''''' + +.. note:: + The Airflow Core version must be ``3.2.0`` or newer to be able to use CLI commands provided by providers. + +Providers can add their own custom CLI commands to the Airflow CLI. Those commands will be available +once you install the provider package. + +You can see all community-managed providers with their own CLI commands in +:doc:`/core-extensions/cli-commands`. + Custom connections '''''''''''''''''' diff --git a/providers/.pre-commit-config.yaml b/providers/.pre-commit-config.yaml index 76da0492312e9..7376116a13cf1 100644 --- a/providers/.pre-commit-config.yaml +++ b/providers/.pre-commit-config.yaml @@ -221,6 +221,12 @@ repos: files: ^.*/provider\.yaml$ exclude: ^.*/.venv/.*$ require_serial: true + - id: check-cli-definition-imports + name: Check CLI definition files only import allowed modules + entry: ../scripts/ci/prek/check_cli_definition_imports.py + language: python + files: ^.*/cli/definition\.py$ + require_serial: false - id: check-imports-in-providers name: Check imports in providers entry: ../scripts/ci/prek/check_imports_in_providers.py diff --git a/providers/amazon/docs/cli-ref.rst b/providers/amazon/docs/cli-ref.rst index 0faf35d47a176..ed6ad0733f865 100644 --- a/providers/amazon/docs/cli-ref.rst +++ b/providers/amazon/docs/cli-ref.rst @@ -19,6 +19,6 @@ Amazon CLI Commands =================== .. argparse:: - :module: airflow.providers.amazon.aws.auth_manager.aws_auth_manager + :module: airflow.providers.amazon.aws.cli.definition :func: get_parser :prog: airflow diff --git a/providers/amazon/provider.yaml b/providers/amazon/provider.yaml index c9c75b3434dd2..48103930e7327 100644 --- a/providers/amazon/provider.yaml +++ b/providers/amazon/provider.yaml @@ -1299,5 +1299,8 @@ executors: auth-managers: - airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager +cli: + - airflow.providers.amazon.aws.cli.definition.get_aws_cli_commands + queues: - airflow.providers.amazon.aws.queues.sqs.SqsMessageQueueProvider diff --git a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py index 1e499a46c1494..df658be1f0c85 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py @@ -16,7 +16,6 @@ # under the License. 
from __future__ import annotations -import argparse from collections import defaultdict from collections.abc import Sequence from functools import cached_property @@ -27,16 +26,13 @@ from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager -from airflow.cli.cli_config import CLICommand, DefaultHelpParser, GroupCommand +from airflow.cli.cli_config import CLICommand from airflow.exceptions import AirflowOptionalProviderFeatureException from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities from airflow.providers.amazon.aws.auth_manager.avp.facade import ( AwsAuthManagerAmazonVerifiedPermissionsFacade, IsAuthorizedRequest, ) -from airflow.providers.amazon.aws.auth_manager.cli.definition import ( - AWS_AUTH_MANAGER_COMMANDS, -) from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS from airflow.providers.common.compat.sdk import conf @@ -62,6 +58,7 @@ VariableDetails, ) from airflow.api_fastapi.common.types import MenuItem + from airflow.cli.cli_config import CLICommand class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]): @@ -468,13 +465,9 @@ def get_url_login(self, **kwargs) -> str: @staticmethod def get_cli_commands() -> list[CLICommand]: """Vends CLI commands to be included in Airflow CLI.""" - return [ - GroupCommand( - name="aws-auth-manager", - help="Manage resources used by AWS auth manager", - subcommands=AWS_AUTH_MANAGER_COMMANDS, - ), - ] + from airflow.providers.amazon.aws.cli.definition import get_aws_cli_commands + + return get_aws_cli_commands() def get_fastapi_app(self) -> FastAPI | None: from airflow.providers.amazon.aws.auth_manager.routes.login import login_router @@ -515,14 +508,3 @@ def _check_avp_schema_version(self): "Please update it to its latest version. " "See doc: https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/auth-manager/setup/amazon-verified-permissions.html#update-the-policy-store-schema." ) - - -def get_parser() -> argparse.ArgumentParser: - """Generate documentation; used by Sphinx argparse.""" - from airflow.cli.cli_parser import AirflowHelpFormatter, _add_command - - parser = DefaultHelpParser(prog="airflow", formatter_class=AirflowHelpFormatter) - subparsers = parser.add_subparsers(dest="subcommand", metavar="GROUP_OR_COMMAND") - for group_command in AwsAuthManager.get_cli_commands(): - _add_command(subparsers, group_command) - return parser diff --git a/providers/amazon/src/airflow/providers/amazon/aws/cli/__init__.py b/providers/amazon/src/airflow/providers/amazon/aws/cli/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/cli/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/cli/definition.py b/providers/amazon/src/airflow/providers/amazon/aws/cli/definition.py similarity index 68% rename from providers/amazon/src/airflow/providers/amazon/aws/auth_manager/cli/definition.py rename to providers/amazon/src/airflow/providers/amazon/aws/cli/definition.py index 4d50fdea68ef8..bc86d27cb0a77 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/cli/definition.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/cli/definition.py @@ -17,12 +17,17 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from airflow.cli.cli_config import ( ActionCommand, Arg, lazy_load_command, ) +if TYPE_CHECKING: + import argparse + ############ # # ARGS # # ############ @@ -58,3 +63,31 @@ args=(ARG_POLICY_STORE_ID, ARG_DRY_RUN), ), ) + + +def get_aws_cli_commands(): + """Return CLI commands for AWS auth manager.""" + from airflow.cli.cli_config import GroupCommand + + return [ + GroupCommand( + name="aws-auth-manager", + help="Manage resources used by AWS auth manager", + subcommands=AWS_AUTH_MANAGER_COMMANDS, + ), + ] + + +def get_parser() -> argparse.ArgumentParser: + """ + Generate documentation; used by Sphinx argparse. + + :meta private: + """ + from airflow.cli.cli_parser import AirflowHelpFormatter, DefaultHelpParser, _add_command + + parser = DefaultHelpParser(prog="airflow", formatter_class=AirflowHelpFormatter) + subparsers = parser.add_subparsers(dest="subcommand", metavar="GROUP_OR_COMMAND") + for group_command in get_aws_cli_commands(): + _add_command(subparsers, group_command) + return parser diff --git a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py index 1b584cf92c6c4..a85c6668b945f 100644 --- a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py +++ b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py @@ -1390,5 +1390,6 @@ def get_provider_info(): }, "executors": ["airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"], "auth-managers": ["airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager"], + "cli": ["airflow.providers.amazon.aws.cli.definition.get_aws_cli_commands"], "queues": ["airflow.providers.amazon.aws.queues.sqs.SqsMessageQueueProvider"], } diff --git a/providers/amazon/tests/unit/amazon/aws/cli/__init__.py b/providers/amazon/tests/unit/amazon/aws/cli/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/cli/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/amazon/tests/unit/amazon/aws/auth_manager/cli/test_definition.py b/providers/amazon/tests/unit/amazon/aws/cli/test_definition.py similarity index 92% rename from providers/amazon/tests/unit/amazon/aws/auth_manager/cli/test_definition.py rename to providers/amazon/tests/unit/amazon/aws/cli/test_definition.py index 426a991958308..89d31da492466 100644 --- a/providers/amazon/tests/unit/amazon/aws/auth_manager/cli/test_definition.py +++ b/providers/amazon/tests/unit/amazon/aws/cli/test_definition.py @@ -23,7 +23,7 @@ if not AIRFLOW_V_3_0_PLUS: pytest.skip("AWS auth manager is only compatible with Airflow >= 3.0.0", allow_module_level=True) -from airflow.providers.amazon.aws.auth_manager.cli.definition import AWS_AUTH_MANAGER_COMMANDS +from airflow.providers.amazon.aws.cli.definition import AWS_AUTH_MANAGER_COMMANDS class TestAwsCliDefinition: diff --git a/providers/celery/docs/cli-ref.rst b/providers/celery/docs/cli-ref.rst index f2e40874f90c4..d0e419564dbb3 100644 --- a/providers/celery/docs/cli-ref.rst +++ b/providers/celery/docs/cli-ref.rst @@ -25,6 +25,6 @@ Celery Executor Commands .. argparse:: - :module: airflow.providers.celery.executors.celery_executor - :func: _get_parser + :module: airflow.providers.celery.cli.definition + :func: get_parser :prog: airflow diff --git a/providers/celery/provider.yaml b/providers/celery/provider.yaml index 5677d37fca6d0..0645d0f355644 100644 --- a/providers/celery/provider.yaml +++ b/providers/celery/provider.yaml @@ -97,6 +97,9 @@ executors: - airflow.providers.celery.executors.celery_executor.CeleryExecutor - airflow.providers.celery.executors.celery_kubernetes_executor.CeleryKubernetesExecutor +cli: + - airflow.providers.celery.cli.definition.get_celery_cli_commands + config: celery_kubernetes_executor: description: | diff --git a/providers/celery/src/airflow/providers/celery/cli/definition.py b/providers/celery/src/airflow/providers/celery/cli/definition.py new file mode 100644 index 0000000000000..964deb4573f47 --- /dev/null +++ b/providers/celery/src/airflow/providers/celery/cli/definition.py @@ -0,0 +1,254 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
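+# This "definition" module declares Celery CLI args and commands without importing
+# Celery itself; the command callables in ``celery_command`` are resolved lazily via
+# ``lazy_load_command``, so building the Airflow CLI parser stays fast. A minimal
+# sketch of how a provider plugs such a module into the CLI (using a hypothetical
+# ``acme`` provider; the names are illustrative, not part of this change):
+#
+#   # provider.yaml
+#   cli:
+#     - airflow.providers.acme.cli.definition.get_acme_cli_commands
+#
+#   # airflow/providers/acme/cli/definition.py
+#   def get_acme_cli_commands() -> list[GroupCommand]:
+#       return [GroupCommand(name="acme", help="Acme commands", subcommands=ACME_COMMANDS)]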
+"""CLI commands for Celery executor.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.cli.cli_config import ( + ARG_DAEMON, + ARG_LOG_FILE, + ARG_PID, + ARG_SKIP_SERVE_LOGS, + ARG_STDERR, + ARG_STDOUT, + ARG_VERBOSE, + ActionCommand, + Arg, + GroupCommand, + lazy_load_command, +) +from airflow.configuration import conf + +if TYPE_CHECKING: + import argparse + +# flower cli args +ARG_BROKER_API = Arg(("-a", "--broker-api"), help="Broker API") +ARG_FLOWER_HOSTNAME = Arg( + ("-H", "--hostname"), + default=conf.get("celery", "FLOWER_HOST"), + help="Set the hostname on which to run the server", +) +ARG_FLOWER_PORT = Arg( + ("-p", "--port"), + default=conf.getint("celery", "FLOWER_PORT"), + type=int, + help="The port on which to run the server", +) +ARG_FLOWER_CONF = Arg(("-c", "--flower-conf"), help="Configuration file for flower") +ARG_FLOWER_URL_PREFIX = Arg( + ("-u", "--url-prefix"), + default=conf.get("celery", "FLOWER_URL_PREFIX"), + help="URL prefix for Flower", +) +ARG_FLOWER_BASIC_AUTH = Arg( + ("-A", "--basic-auth"), + default=conf.get("celery", "FLOWER_BASIC_AUTH"), + help=( + "Securing Flower with Basic Authentication. " + "Accepts user:password pairs separated by a comma. " + "Example: flower_basic_auth = user1:password1,user2:password2" + ), +) + +# worker cli args +ARG_AUTOSCALE = Arg(("-a", "--autoscale"), help="Minimum and Maximum number of worker to autoscale") +ARG_QUEUES = Arg( + ("-q", "--queues"), + help="Comma delimited list of queues to serve", + default=conf.get("operators", "DEFAULT_QUEUE"), +) +ARG_CONCURRENCY = Arg( + ("-c", "--concurrency"), + type=int, + help="The number of worker processes", + default=conf.getint("celery", "worker_concurrency"), +) +ARG_CELERY_HOSTNAME = Arg( + ("-H", "--celery-hostname"), + help="Set the hostname of celery worker if you have multiple workers on a single machine", +) +ARG_UMASK = Arg( + ("-u", "--umask"), + help="Set the umask of celery worker in daemon mode", +) + +ARG_WITHOUT_MINGLE = Arg( + ("--without-mingle",), + default=False, + help="Don't synchronize with other workers at start-up", + action="store_true", +) +ARG_WITHOUT_GOSSIP = Arg( + ("--without-gossip",), + default=False, + help="Don't subscribe to other workers events", + action="store_true", +) +ARG_OUTPUT = Arg( + ( + "-o", + "--output", + ), + help="Output format. Allowed values: json, yaml, plain, table (default: table)", + metavar="(table, json, yaml, plain)", + choices=("table", "json", "yaml", "plain"), + default="table", +) +ARG_FULL_CELERY_HOSTNAME = Arg( + ("-H", "--celery-hostname"), + required=True, + help="Specify the full celery hostname. example: celery@hostname", +) +ARG_REQUIRED_QUEUES = Arg( + ("-q", "--queues"), + help="Comma delimited list of queues to serve", + required=True, +) +ARG_YES = Arg( + ("-y", "--yes"), + help="Do not prompt to confirm. 
Use with care!", + action="store_true", + default=False, +) + +CELERY_CLI_COMMAND_PATH = "airflow.providers.celery.cli.celery_command" + +CELERY_COMMANDS = ( + ActionCommand( + name="worker", + help="Start a Celery worker node", + func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.worker"), + args=( + ARG_QUEUES, + ARG_CONCURRENCY, + ARG_CELERY_HOSTNAME, + ARG_PID, + ARG_DAEMON, + ARG_UMASK, + ARG_STDOUT, + ARG_STDERR, + ARG_LOG_FILE, + ARG_AUTOSCALE, + ARG_SKIP_SERVE_LOGS, + ARG_WITHOUT_MINGLE, + ARG_WITHOUT_GOSSIP, + ARG_VERBOSE, + ), + ), + ActionCommand( + name="flower", + help="Start a Celery Flower", + func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.flower"), + args=( + ARG_FLOWER_HOSTNAME, + ARG_FLOWER_PORT, + ARG_FLOWER_CONF, + ARG_FLOWER_URL_PREFIX, + ARG_FLOWER_BASIC_AUTH, + ARG_BROKER_API, + ARG_PID, + ARG_DAEMON, + ARG_STDOUT, + ARG_STDERR, + ARG_LOG_FILE, + ARG_VERBOSE, + ), + ), + ActionCommand( + name="stop", + help="Stop the Celery worker gracefully", + func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.stop_worker"), + args=(ARG_PID, ARG_VERBOSE), + ), + ActionCommand( + name="list-workers", + help="List active celery workers", + func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.list_workers"), + args=(ARG_OUTPUT,), + ), + ActionCommand( + name="shutdown-worker", + help="Request graceful shutdown of celery workers", + func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.shutdown_worker"), + args=(ARG_FULL_CELERY_HOSTNAME,), + ), + ActionCommand( + name="shutdown-all-workers", + help="Request graceful shutdown of all active celery workers", + func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.shutdown_all_workers"), + args=(ARG_YES,), + ), + ActionCommand( + name="add-queue", + help="Subscribe Celery worker to specified queues", + func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.add_queue"), + args=( + ARG_REQUIRED_QUEUES, + ARG_FULL_CELERY_HOSTNAME, + ), + ), + ActionCommand( + name="remove-queue", + help="Unsubscribe Celery worker from specified queues", + func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.remove_queue"), + args=( + ARG_REQUIRED_QUEUES, + ARG_FULL_CELERY_HOSTNAME, + ), + ), + ActionCommand( + name="remove-all-queues", + help="Unsubscribe Celery worker from all its active queues", + func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.remove_all_queues"), + args=(ARG_FULL_CELERY_HOSTNAME,), + ), +) + +CELERY_CLI_COMMANDS = [ + GroupCommand( + name="celery", + help="Celery components", + description=( + "Start celery components. Works only when using CeleryExecutor. For more information, " + "see https://airflow.apache.org/docs/apache-airflow/stable/executor/celery.html" + ), + subcommands=CELERY_COMMANDS, + ), +] + + +def get_celery_cli_commands(): + """Return CLI commands for Celery executor.""" + return CELERY_CLI_COMMANDS + + +def get_parser() -> argparse.ArgumentParser: + """ + Generate documentation; used by Sphinx. 
+ + :meta private: + """ + from airflow.cli.cli_parser import AirflowHelpFormatter, DefaultHelpParser, _add_command + + parser = DefaultHelpParser(prog="airflow", formatter_class=AirflowHelpFormatter) + subparsers = parser.add_subparsers(dest="subcommand", metavar="GROUP_OR_COMMAND") + for group_command in get_celery_cli_commands(): + _add_command(subparsers, group_command) + return parser diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py index 367a8c7969ae4..6e3a322d0cac9 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py @@ -37,19 +37,6 @@ from celery import states as celery_states from deprecated import deprecated -from airflow.cli.cli_config import ( - ARG_DAEMON, - ARG_LOG_FILE, - ARG_PID, - ARG_SKIP_SERVE_LOGS, - ARG_STDERR, - ARG_STDOUT, - ARG_VERBOSE, - ActionCommand, - Arg, - GroupCommand, - lazy_load_command, -) from airflow.configuration import conf from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.executors.base_executor import BaseExecutor @@ -64,11 +51,11 @@ if TYPE_CHECKING: - import argparse from collections.abc import Sequence from sqlalchemy.orm import Session + from airflow.cli.cli_config import GroupCommand from airflow.executors import workloads from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey @@ -93,190 +80,6 @@ def __getattr__(name): """ -# flower cli args -ARG_BROKER_API = Arg(("-a", "--broker-api"), help="Broker API") -ARG_FLOWER_HOSTNAME = Arg( - ("-H", "--hostname"), - default=conf.get("celery", "FLOWER_HOST"), - help="Set the hostname on which to run the server", -) -ARG_FLOWER_PORT = Arg( - ("-p", "--port"), - default=conf.getint("celery", "FLOWER_PORT"), - type=int, - help="The port on which to run the server", -) -ARG_FLOWER_CONF = Arg(("-c", "--flower-conf"), help="Configuration file for flower") -ARG_FLOWER_URL_PREFIX = Arg( - ("-u", "--url-prefix"), - default=conf.get("celery", "FLOWER_URL_PREFIX"), - help="URL prefix for Flower", -) -ARG_FLOWER_BASIC_AUTH = Arg( - ("-A", "--basic-auth"), - default=conf.get("celery", "FLOWER_BASIC_AUTH"), - help=( - "Securing Flower with Basic Authentication. " - "Accepts user:password pairs separated by a comma. 
" - "Example: flower_basic_auth = user1:password1,user2:password2" - ), -) - -# worker cli args -ARG_AUTOSCALE = Arg(("-a", "--autoscale"), help="Minimum and Maximum number of worker to autoscale") -ARG_QUEUES = Arg( - ("-q", "--queues"), - help="Comma delimited list of queues to serve", - default=conf.get("operators", "DEFAULT_QUEUE"), -) -ARG_CONCURRENCY = Arg( - ("-c", "--concurrency"), - type=int, - help="The number of worker processes", - default=conf.getint("celery", "worker_concurrency"), -) -ARG_CELERY_HOSTNAME = Arg( - ("-H", "--celery-hostname"), - help="Set the hostname of celery worker if you have multiple workers on a single machine", -) -ARG_UMASK = Arg( - ("-u", "--umask"), - help="Set the umask of celery worker in daemon mode", -) - -ARG_WITHOUT_MINGLE = Arg( - ("--without-mingle",), - default=False, - help="Don't synchronize with other workers at start-up", - action="store_true", -) -ARG_WITHOUT_GOSSIP = Arg( - ("--without-gossip",), - default=False, - help="Don't subscribe to other workers events", - action="store_true", -) -ARG_OUTPUT = Arg( - ( - "-o", - "--output", - ), - help="Output format. Allowed values: json, yaml, plain, table (default: table)", - metavar="(table, json, yaml, plain)", - choices=("table", "json", "yaml", "plain"), - default="table", -) -ARG_FULL_CELERY_HOSTNAME = Arg( - ("-H", "--celery-hostname"), - required=True, - help="Specify the full celery hostname. example: celery@hostname", -) -ARG_REQUIRED_QUEUES = Arg( - ("-q", "--queues"), - help="Comma delimited list of queues to serve", - required=True, -) -ARG_YES = Arg( - ("-y", "--yes"), - help="Do not prompt to confirm. Use with care!", - action="store_true", - default=False, -) - -CELERY_CLI_COMMAND_PATH = "airflow.providers.celery.cli.celery_command" - -CELERY_COMMANDS = ( - ActionCommand( - name="worker", - help="Start a Celery worker node", - func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.worker"), - args=( - ARG_QUEUES, - ARG_CONCURRENCY, - ARG_CELERY_HOSTNAME, - ARG_PID, - ARG_DAEMON, - ARG_UMASK, - ARG_STDOUT, - ARG_STDERR, - ARG_LOG_FILE, - ARG_AUTOSCALE, - ARG_SKIP_SERVE_LOGS, - ARG_WITHOUT_MINGLE, - ARG_WITHOUT_GOSSIP, - ARG_VERBOSE, - ), - ), - ActionCommand( - name="flower", - help="Start a Celery Flower", - func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.flower"), - args=( - ARG_FLOWER_HOSTNAME, - ARG_FLOWER_PORT, - ARG_FLOWER_CONF, - ARG_FLOWER_URL_PREFIX, - ARG_FLOWER_BASIC_AUTH, - ARG_BROKER_API, - ARG_PID, - ARG_DAEMON, - ARG_STDOUT, - ARG_STDERR, - ARG_LOG_FILE, - ARG_VERBOSE, - ), - ), - ActionCommand( - name="stop", - help="Stop the Celery worker gracefully", - func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.stop_worker"), - args=(ARG_PID, ARG_VERBOSE), - ), - ActionCommand( - name="list-workers", - help="List active celery workers", - func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.list_workers"), - args=(ARG_OUTPUT,), - ), - ActionCommand( - name="shutdown-worker", - help="Request graceful shutdown of celery workers", - func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.shutdown_worker"), - args=(ARG_FULL_CELERY_HOSTNAME,), - ), - ActionCommand( - name="shutdown-all-workers", - help="Request graceful shutdown of all active celery workers", - func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.shutdown_all_workers"), - args=(ARG_YES,), - ), - ActionCommand( - name="add-queue", - help="Subscribe Celery worker to specified queues", - func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.add_queue"), - args=( - ARG_REQUIRED_QUEUES, - ARG_FULL_CELERY_HOSTNAME, - ), - ), 
- ActionCommand( - name="remove-queue", - help="Unsubscribe Celery worker from specified queues", - func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.remove_queue"), - args=( - ARG_REQUIRED_QUEUES, - ARG_FULL_CELERY_HOSTNAME, - ), - ), - ActionCommand( - name="remove-all-queues", - help="Unsubscribe Celery worker from all its active queues", - func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.remove_all_queues"), - args=(ARG_FULL_CELERY_HOSTNAME,), - ), -) - - class CeleryExecutor(BaseExecutor): """ CeleryExecutor is recommended for production use of Airflow. @@ -553,17 +356,9 @@ def revoke_task(self, *, ti: TaskInstance): @staticmethod def get_cli_commands() -> list[GroupCommand]: - return [ - GroupCommand( - name="celery", - help="Celery components", - description=( - "Start celery components. Works only when using CeleryExecutor. For more information, " - "see https://airflow.apache.org/docs/apache-airflow/stable/executor/celery.html" - ), - subcommands=CELERY_COMMANDS, - ), - ] + from airflow.providers.celery.cli.definition import get_celery_cli_commands + + return get_celery_cli_commands() def queue_workload(self, workload: workloads.All, session: Session | None) -> None: from airflow.executors import workloads @@ -572,12 +367,3 @@ def queue_workload(self, workload: workloads.All, session: Session | None) -> No raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}") ti = workload.ti self.queued_tasks[ti.key] = workload - - -def _get_parser() -> argparse.ArgumentParser: - """ - Generate documentation; used by Sphinx. - - :meta private: - """ - return CeleryExecutor._get_parser() diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py index 78025db44d038..49ae5b35b6f52 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py @@ -24,26 +24,22 @@ from deprecated import deprecated from airflow.configuration import conf -from airflow.exceptions import AirflowOptionalProviderFeatureException, AirflowProviderDeprecationWarning +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.executors.base_executor import BaseExecutor from airflow.providers.celery.executors.celery_executor import AIRFLOW_V_3_0_PLUS, CeleryExecutor - -try: - from airflow.providers.cncf.kubernetes.executors.kubernetes_executor import KubernetesExecutor -except ImportError as e: - raise AirflowOptionalProviderFeatureException(e) - from airflow.utils.providers_configuration_loader import providers_configuration_loaded if TYPE_CHECKING: from airflow.callbacks.base_callback_sink import BaseCallbackSink from airflow.callbacks.callback_requests import CallbackRequest + from airflow.cli.cli_config import GroupCommand from airflow.executors.base_executor import EventBufferValueType from airflow.models.taskinstance import ( # type: ignore[attr-defined] SimpleTaskInstance, TaskInstance, ) from airflow.models.taskinstancekey import TaskInstanceKey + from airflow.providers.cncf.kubernetes.executors.kubernetes_executor import KubernetesExecutor CommandType = Sequence[str] @@ -330,5 +326,8 @@ def send_callback(self, request: CallbackRequest) -> None: self.callback_sink.send(request) @staticmethod - def get_cli_commands() -> list: - return CeleryExecutor.get_cli_commands() + KubernetesExecutor.get_cli_commands() + def 
get_cli_commands() -> list[GroupCommand]: + from airflow.providers.celery.cli.definition import get_celery_cli_commands + from airflow.providers.cncf.kubernetes.cli.definition import get_kubernetes_cli_commands + + return get_celery_cli_commands() + get_kubernetes_cli_commands() diff --git a/providers/celery/src/airflow/providers/celery/get_provider_info.py b/providers/celery/src/airflow/providers/celery/get_provider_info.py index 9c4b5cda2191b..48097457f5d81 100644 --- a/providers/celery/src/airflow/providers/celery/get_provider_info.py +++ b/providers/celery/src/airflow/providers/celery/get_provider_info.py @@ -44,6 +44,7 @@ def get_provider_info(): "airflow.providers.celery.executors.celery_executor.CeleryExecutor", "airflow.providers.celery.executors.celery_kubernetes_executor.CeleryKubernetesExecutor", ], + "cli": ["airflow.providers.celery.cli.definition.get_celery_cli_commands"], "config": { "celery_kubernetes_executor": { "description": "This section only applies if you are using the ``CeleryKubernetesExecutor`` in\n``[core]`` section above\n", diff --git a/providers/celery/tests/unit/celery/cli/test_definition.py b/providers/celery/tests/unit/celery/cli/test_definition.py new file mode 100644 index 0000000000000..724dca8db3284 --- /dev/null +++ b/providers/celery/tests/unit/celery/cli/test_definition.py @@ -0,0 +1,170 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
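+# These tests drive real ``airflow celery ...`` argv through the shared CLI parser.
+# ``cli_parser`` collects provider commands when it is imported, so each test reloads
+# the module (and clears the ``get_parser`` cache) to pick up the Celery commands; on
+# Airflow < 3.2 provider commands are only vended for the configured executor, hence
+# the ``CeleryExecutor`` config override in the fixture below.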
+from __future__ import annotations + +import importlib + +import pytest + +from airflow.cli import cli_parser +from airflow.providers.celery.cli.definition import CELERY_CLI_COMMANDS, CELERY_COMMANDS + +from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS + + +class TestCeleryCliDefinition: + @pytest.fixture(autouse=True) + def setup_parser(self): + if AIRFLOW_V_3_2_PLUS: + importlib.reload(cli_parser) + cli_parser.get_parser.cache_clear() + self.arg_parser = cli_parser.get_parser() + else: + with conf_vars( + { + ( + "core", + "executor", + ): "CeleryExecutor", + } + ): + importlib.reload(cli_parser) + cli_parser.get_parser.cache_clear() + self.arg_parser = cli_parser.get_parser() + + def test_celery_cli_commands_count(self): + """Test that CELERY_CLI_COMMANDS contains exactly 1 GroupCommand.""" + assert len(CELERY_CLI_COMMANDS) == 1 + + def test_celery_commands_count(self): + """Test that CELERY_COMMANDS contains all 9 subcommands.""" + assert len(CELERY_COMMANDS) == 9 + + @pytest.mark.parametrize( + "command", + [ + "worker", + "flower", + "stop", + "list-workers", + "shutdown-worker", + "shutdown-all-workers", + "add-queue", + "remove-queue", + "remove-all-queues", + ], + ) + def test_celery_subcommands_defined(self, command): + """Test that all celery subcommands are properly defined.""" + params = ["celery", command, "--help"] + with pytest.raises(SystemExit) as exc_info: + self.arg_parser.parse_args(params) + # --help exits with code 0 + assert exc_info.value.code == 0 + + def test_worker_command_args(self): + """Test worker command with various arguments.""" + params = [ + "celery", + "worker", + "--queues", + "queue1,queue2", + "--concurrency", + "4", + "--celery-hostname", + "worker1", + ] + args = self.arg_parser.parse_args(params) + assert args.queues == "queue1,queue2" + assert args.concurrency == 4 + assert args.celery_hostname == "worker1" + + def test_flower_command_args(self): + """Test flower command with various arguments.""" + params = [ + "celery", + "flower", + "--hostname", + "localhost", + "--port", + "5555", + "--url-prefix", + "/flower", + ] + args = self.arg_parser.parse_args(params) + assert args.hostname == "localhost" + assert args.port == 5555 + assert args.url_prefix == "/flower" + + def test_list_workers_command_args(self): + """Test list-workers command with output format.""" + params = ["celery", "list-workers", "--output", "json"] + args = self.arg_parser.parse_args(params) + assert args.output == "json" + + def test_shutdown_worker_command_args(self): + """Test shutdown-worker command with celery hostname.""" + params = ["celery", "shutdown-worker", "--celery-hostname", "celery@worker1"] + args = self.arg_parser.parse_args(params) + assert args.celery_hostname == "celery@worker1" + + def test_shutdown_all_workers_command_args(self): + """Test shutdown-all-workers command with yes flag.""" + params = ["celery", "shutdown-all-workers", "--yes"] + args = self.arg_parser.parse_args(params) + assert args.yes is True + + def test_add_queue_command_args(self): + """Test add-queue command with required arguments.""" + params = [ + "celery", + "add-queue", + "--queues", + "new_queue", + "--celery-hostname", + "celery@worker1", + ] + args = self.arg_parser.parse_args(params) + assert args.queues == "new_queue" + assert args.celery_hostname == "celery@worker1" + + def test_remove_queue_command_args(self): + """Test remove-queue command with required arguments.""" + params = [ + "celery", + 
"remove-queue", + "--queues", + "old_queue", + "--celery-hostname", + "celery@worker1", + ] + args = self.arg_parser.parse_args(params) + assert args.queues == "old_queue" + assert args.celery_hostname == "celery@worker1" + + def test_remove_all_queues_command_args(self): + """Test remove-all-queues command with celery hostname.""" + params = ["celery", "remove-all-queues", "--celery-hostname", "celery@worker1"] + args = self.arg_parser.parse_args(params) + assert args.celery_hostname == "celery@worker1" + + def test_stop_command_args(self): + """Test stop command with pid argument.""" + params = ["celery", "stop", "--pid", "/path/to/pid"] + args = self.arg_parser.parse_args(params) + assert args.pid == "/path/to/pid" diff --git a/providers/cncf/kubernetes/docs/cli-ref.rst b/providers/cncf/kubernetes/docs/cli-ref.rst index daa7d047f3467..9c7d1f3ef9a6a 100644 --- a/providers/cncf/kubernetes/docs/cli-ref.rst +++ b/providers/cncf/kubernetes/docs/cli-ref.rst @@ -24,6 +24,6 @@ Kubernetes Executor Commands the core Airflow documentation for the list of CLI commands and parameters available. .. argparse:: - :module: airflow.providers.cncf.kubernetes.executors.kubernetes_executor - :func: _get_parser + :module: airflow.providers.cncf.kubernetes.cli.definition + :func: get_parser :prog: airflow diff --git a/providers/cncf/kubernetes/provider.yaml b/providers/cncf/kubernetes/provider.yaml index b7a64c37d64e3..ee212a2e92a6a 100644 --- a/providers/cncf/kubernetes/provider.yaml +++ b/providers/cncf/kubernetes/provider.yaml @@ -398,3 +398,6 @@ config: executors: - airflow.providers.cncf.kubernetes.executors.kubernetes_executor.KubernetesExecutor + +cli: + - airflow.providers.cncf.kubernetes.cli.definition.get_kubernetes_cli_commands diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/definition.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/definition.py new file mode 100644 index 0000000000000..210187bf9ced8 --- /dev/null +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/definition.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.cli.cli_config import ( + ARG_DAG_ID, + ARG_OUTPUT_PATH, + ARG_VERBOSE, + ActionCommand, + Arg, + GroupCommand, + lazy_load_command, + positive_int, +) +from airflow.configuration import conf +from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS + +if TYPE_CHECKING: + import argparse + + +try: + from airflow.cli.cli_config import ARG_LOGICAL_DATE +except ImportError: # 2.x compatibility. 
+ from airflow.cli.cli_config import ( # type: ignore[attr-defined, no-redef] + ARG_EXECUTION_DATE as ARG_LOGICAL_DATE, + ) + +if AIRFLOW_V_3_0_PLUS: + from airflow.cli.cli_config import ARG_BUNDLE_NAME + + ARG_COMPAT = ARG_BUNDLE_NAME +else: + from airflow.cli.cli_config import ARG_SUBDIR # type: ignore[attr-defined] + + ARG_COMPAT = ARG_SUBDIR + +# CLI Args +ARG_NAMESPACE = Arg( + ("--namespace",), + default=conf.get("kubernetes_executor", "namespace"), + help="Kubernetes Namespace. Default value is `[kubernetes] namespace` in configuration.", +) + +ARG_MIN_PENDING_MINUTES = Arg( + ("--min-pending-minutes",), + default=30, + type=positive_int(allow_zero=False), + help=( + "Pending pods created before the time interval are to be cleaned up, " + "measured in minutes. Default value is 30(m). The minimum value is 5(m)." + ), +) + +# CLI Commands +KUBERNETES_COMMANDS = ( + ActionCommand( + name="cleanup-pods", + help=( + "Clean up Kubernetes pods " + "(created by KubernetesExecutor/KubernetesPodOperator) " + "in evicted/failed/succeeded/pending states" + ), + func=lazy_load_command("airflow.providers.cncf.kubernetes.cli.kubernetes_command.cleanup_pods"), + args=(ARG_NAMESPACE, ARG_MIN_PENDING_MINUTES, ARG_VERBOSE), + ), + ActionCommand( + name="generate-dag-yaml", + help="Generate YAML files for all tasks in DAG. Useful for debugging tasks without " + "launching into a cluster", + func=lazy_load_command("airflow.providers.cncf.kubernetes.cli.kubernetes_command.generate_pod_yaml"), + args=(ARG_DAG_ID, ARG_LOGICAL_DATE, ARG_COMPAT, ARG_OUTPUT_PATH, ARG_VERBOSE), + ), +) + + +def get_kubernetes_cli_commands() -> list[GroupCommand]: + return [ + GroupCommand( + name="kubernetes", + help="Tools to help run the KubernetesExecutor", + subcommands=KUBERNETES_COMMANDS, + ) + ] + + +def get_parser() -> argparse.ArgumentParser: + """ + Generate documentation; used by Sphinx. + + :meta private: + """ + from airflow.cli.cli_parser import AirflowHelpFormatter, DefaultHelpParser, _add_command + + parser = DefaultHelpParser(prog="airflow", formatter_class=AirflowHelpFormatter) + subparsers = parser.add_subparsers(dest="subcommand", metavar="GROUP_OR_COMMAND") + for group_command in get_kubernetes_cli_commands(): + _add_command(subparsers, group_command) + return parser diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py index 4756405523cda..fb1b7b54d04c9 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py @@ -39,25 +39,6 @@ from kubernetes.dynamic import DynamicClient from sqlalchemy import select -from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator -from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS - -try: - from airflow.cli.cli_config import ARG_LOGICAL_DATE -except ImportError: # 2.x compatibility. 
- from airflow.cli.cli_config import ( # type: ignore[attr-defined, no-redef] - ARG_EXECUTION_DATE as ARG_LOGICAL_DATE, - ) -from airflow.cli.cli_config import ( - ARG_DAG_ID, - ARG_OUTPUT_PATH, - ARG_VERBOSE, - ActionCommand, - Arg, - GroupCommand, - lazy_load_command, - positive_int, -) from airflow.configuration import conf from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.executors.base_executor import BaseExecutor @@ -70,19 +51,21 @@ ) from airflow.providers.cncf.kubernetes.kube_config import KubeConfig from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import annotations_to_key +from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator +from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS from airflow.providers.common.compat.sdk import Stats from airflow.utils.log.logging_mixin import remove_escape_codes from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: - import argparse from collections.abc import Sequence from kubernetes import client from kubernetes.client import models as k8s from sqlalchemy.orm import Session + from airflow.cli.cli_config import GroupCommand from airflow.executors import workloads from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey @@ -91,54 +74,6 @@ ) -if AIRFLOW_V_3_0_PLUS: - from airflow.cli.cli_config import ARG_BUNDLE_NAME - - ARG_COMPAT = ARG_BUNDLE_NAME -else: - from airflow.cli.cli_config import ARG_SUBDIR # type: ignore[attr-defined] - - ARG_COMPAT = ARG_SUBDIR - -# CLI Args -ARG_NAMESPACE = Arg( - ("--namespace",), - default=conf.get("kubernetes_executor", "namespace"), - help="Kubernetes Namespace. Default value is `[kubernetes] namespace` in configuration.", -) - -ARG_MIN_PENDING_MINUTES = Arg( - ("--min-pending-minutes",), - default=30, - type=positive_int(allow_zero=False), - help=( - "Pending pods created before the time interval are to be cleaned up, " - "measured in minutes. Default value is 30(m). The minimum value is 5(m)." - ), -) - -# CLI Commands -KUBERNETES_COMMANDS = ( - ActionCommand( - name="cleanup-pods", - help=( - "Clean up Kubernetes pods " - "(created by KubernetesExecutor/KubernetesPodOperator) " - "in evicted/failed/succeeded/pending states" - ), - func=lazy_load_command("airflow.providers.cncf.kubernetes.cli.kubernetes_command.cleanup_pods"), - args=(ARG_NAMESPACE, ARG_MIN_PENDING_MINUTES, ARG_VERBOSE), - ), - ActionCommand( - name="generate-dag-yaml", - help="Generate YAML files for all tasks in DAG. Useful for debugging tasks without " - "launching into a cluster", - func=lazy_load_command("airflow.providers.cncf.kubernetes.cli.kubernetes_command.generate_pod_yaml"), - args=(ARG_DAG_ID, ARG_LOGICAL_DATE, ARG_COMPAT, ARG_OUTPUT_PATH, ARG_VERBOSE), - ), -) - - class KubernetesExecutor(BaseExecutor): """Executor for Kubernetes.""" @@ -812,19 +747,6 @@ def terminate(self): @staticmethod def get_cli_commands() -> list[GroupCommand]: - return [ - GroupCommand( - name="kubernetes", - help="Tools to help run the KubernetesExecutor", - subcommands=KUBERNETES_COMMANDS, - ) - ] - - -def _get_parser() -> argparse.ArgumentParser: - """ - Generate documentation; used by Sphinx. 
+ from airflow.providers.cncf.kubernetes.cli.definition import get_kubernetes_cli_commands - :meta private: - """ - return KubernetesExecutor._get_parser() + return get_kubernetes_cli_commands() diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py index 0cc7fa9dbc1f1..114da7ec36fe8 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py @@ -25,12 +25,12 @@ from airflow.configuration import conf from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.executors.base_executor import BaseExecutor -from airflow.providers.cncf.kubernetes.executors.kubernetes_executor import KubernetesExecutor from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: from airflow.callbacks.base_callback_sink import BaseCallbackSink from airflow.callbacks.callback_requests import CallbackRequest + from airflow.cli.cli_config import GroupCommand from airflow.executors.base_executor import EventBufferValueType from airflow.executors.local_executor import LocalExecutor from airflow.models.taskinstance import ( # type: ignore[attr-defined] @@ -38,6 +38,7 @@ TaskInstance, TaskInstanceKey, ) + from airflow.providers.cncf.kubernetes.executors.kubernetes_executor import KubernetesExecutor CommandType = Sequence[str] @@ -302,5 +303,7 @@ def send_callback(self, request: CallbackRequest) -> None: self.callback_sink.send(request) @staticmethod - def get_cli_commands() -> list: - return KubernetesExecutor.get_cli_commands() + def get_cli_commands() -> list[GroupCommand]: + from airflow.providers.cncf.kubernetes.cli.definition import get_kubernetes_cli_commands + + return get_kubernetes_cli_commands() diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/get_provider_info.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/get_provider_info.py index 963178ab645e2..80d37dd17bc6e 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/get_provider_info.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/get_provider_info.py @@ -286,4 +286,5 @@ def get_provider_info(): }, }, "executors": ["airflow.providers.cncf.kubernetes.executors.kubernetes_executor.KubernetesExecutor"], + "cli": ["airflow.providers.cncf.kubernetes.cli.definition.get_kubernetes_cli_commands"], } diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/cli/test_definition.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/cli/test_definition.py new file mode 100644 index 0000000000000..4ffd0f224a90f --- /dev/null +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/cli/test_definition.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import importlib +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.cli import cli_parser +from airflow.providers.cncf.kubernetes.cli.definition import ( + KUBERNETES_COMMANDS, + get_kubernetes_cli_commands, +) + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS + + +class TestKubernetesCliDefinition: + @pytest.fixture(autouse=True) + def setup_parser(self): + if AIRFLOW_V_3_2_PLUS: + importlib.reload(cli_parser) + cli_parser.get_parser.cache_clear() + self.arg_parser = cli_parser.get_parser() + else: + with patch( + "airflow.executors.executor_loader.ExecutorLoader.get_executor_names", + ) as mock_get_executor_names: + mock_get_executor_names.return_value = [ + MagicMock( + name="KubernetesExecutor", + module_path="airflow.providers.cncf.kubernetes.executors.kubernetes_executor.KubernetesExecutor", + ) + ] + importlib.reload(cli_parser) + cli_parser.get_parser.cache_clear() + self.arg_parser = cli_parser.get_parser() + + def test_kubernetes_cli_commands_count(self): + """Test that get_kubernetes_cli_commands returns exactly 1 GroupCommand.""" + commands = get_kubernetes_cli_commands() + assert len(commands) == 1 + + def test_kubernetes_commands_count(self): + """Test that KUBERNETES_COMMANDS contains all 2 subcommands.""" + assert len(KUBERNETES_COMMANDS) == 2 + + @pytest.mark.parametrize( + "command", + [ + "cleanup-pods", + "generate-dag-yaml", + ], + ) + def test_kubernetes_subcommands_defined(self, command): + """Test that all kubernetes subcommands are properly defined.""" + params = ["kubernetes", command, "--help"] + with pytest.raises(SystemExit) as exc_info: + self.arg_parser.parse_args(params) + # --help exits with code 0 + assert exc_info.value.code == 0 + + def test_cleanup_pods_command_args(self): + """Test cleanup-pods command with various arguments.""" + params = [ + "kubernetes", + "cleanup-pods", + "--namespace", + "my-namespace", + "--min-pending-minutes", + "60", + ] + args = self.arg_parser.parse_args(params) + assert args.namespace == "my-namespace" + assert args.min_pending_minutes == 60 + + def test_cleanup_pods_command_default_args(self): + """Test cleanup-pods command with default arguments.""" + params = ["kubernetes", "cleanup-pods"] + args = self.arg_parser.parse_args(params) + # Should use default values from configuration + assert hasattr(args, "namespace") + assert args.min_pending_minutes == 30 + + def test_generate_dag_yaml_command_args(self): + """Test generate-dag-yaml command with various arguments.""" + if AIRFLOW_V_3_0_PLUS: + params = [ + "kubernetes", + "generate-dag-yaml", + "my_dag", + "--logical-date", + "2024-01-01T00:00:00+00:00", + "--output-path", + "/tmp/output", + ] + args = self.arg_parser.parse_args(params) + assert args.logical_date == datetime.fromisoformat("2024-01-01T00:00:00+00:00") + else: + params = [ + "kubernetes", + "generate-dag-yaml", + "--output-path", + "/tmp/output", + "my_dag", + "2024-01-01T00:00:00+00:00", + ] + args = self.arg_parser.parse_args(params) + assert 
args.execution_date == datetime.fromisoformat("2024-01-01T00:00:00+00:00") + + assert args.dag_id == "my_dag" + assert args.output_path == "/tmp/output" diff --git a/providers/edge3/docs/cli-ref.rst b/providers/edge3/docs/cli-ref.rst index 264c825c191d4..7f9a9044de8fc 100644 --- a/providers/edge3/docs/cli-ref.rst +++ b/providers/edge3/docs/cli-ref.rst @@ -19,6 +19,6 @@ Edge Executor Commands ---------------------- .. argparse:: - :module: airflow.providers.edge3.executors.edge_executor - :func: _get_parser + :module: airflow.providers.edge3.cli.definition + :func: get_parser :prog: airflow diff --git a/providers/edge3/provider.yaml b/providers/edge3/provider.yaml index 482ffa1f2bd16..f0c334349c0d4 100644 --- a/providers/edge3/provider.yaml +++ b/providers/edge3/provider.yaml @@ -59,6 +59,9 @@ plugins: - name: edge_executor plugin-class: airflow.providers.edge3.plugins.edge_executor_plugin.EdgeExecutorPlugin +cli: + - airflow.providers.edge3.cli.definition.get_edge_cli_commands + executors: - airflow.providers.edge3.executors.EdgeExecutor diff --git a/providers/edge3/src/airflow/providers/edge3/cli/definition.py b/providers/edge3/src/airflow/providers/edge3/cli/definition.py new file mode 100644 index 0000000000000..5cd611cc3a438 --- /dev/null +++ b/providers/edge3/src/airflow/providers/edge3/cli/definition.py @@ -0,0 +1,261 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.cli.cli_config import ARG_PID, ARG_VERBOSE, ActionCommand, Arg, GroupCommand, lazy_load_command +from airflow.configuration import conf + +if TYPE_CHECKING: + import argparse + + +ARG_CONCURRENCY = Arg( + ("-c", "--concurrency"), + type=int, + help="The number of worker processes", + default=conf.getint("edge", "worker_concurrency", fallback=8), +) +ARG_QUEUES = Arg( + ("-q", "--queues"), + help="Comma delimited list of queues to serve, serve all queues if not provided.", +) +ARG_EDGE_HOSTNAME = Arg( + ("-H", "--edge-hostname"), + help="Set the hostname of worker if you have multiple workers on a single machine", +) +ARG_REQUIRED_EDGE_HOSTNAME = Arg( + ("-H", "--edge-hostname"), + help="Set the hostname of worker if you have multiple workers on a single machine", + required=True, +) +ARG_MAINTENANCE = Arg(("maintenance",), help="Desired maintenance state", choices=("on", "off")) +ARG_MAINTENANCE_COMMENT = Arg( + ("-c", "--comments"), + help="Maintenance comments to report reason. Required if maintenance is turned on.", +) +ARG_REQUIRED_MAINTENANCE_COMMENT = Arg( + ("-c", "--comments"), + help="Maintenance comments to report reason. 
Required if enabling maintenance", + required=True, +) +ARG_QUEUES_MANAGE = Arg( + ("-q", "--queues"), + help="Comma delimited list of queues to add or remove.", + required=True, +) +ARG_WAIT_MAINT = Arg( + ("-w", "--wait"), + default=False, + help="Wait until edge worker has reached desired state.", + action="store_true", +) +ARG_WAIT_STOP = Arg( + ("-w", "--wait"), + default=False, + help="Wait until edge worker is shut down.", + action="store_true", +) +ARG_OUTPUT = Arg( + ( + "-o", + "--output", + ), + help="Output format. Allowed values: json, yaml, plain, table (default: table)", + metavar="(table, json, yaml, plain)", + choices=("table", "json", "yaml", "plain"), + default="table", +) +ARG_STATE = Arg( + ( + "-s", + "--state", + ), + nargs="+", + help="State of the edge worker", +) + +ARG_DAEMON = Arg( + ("-D", "--daemon"), help="Daemonize instead of running in the foreground", action="store_true" +) +ARG_UMASK = Arg( + ("-u", "--umask"), + help="Set the umask of edge worker in daemon mode", +) +ARG_STDERR = Arg(("--stderr",), help="Redirect stderr to this file if run in daemon mode") +ARG_STDOUT = Arg(("--stdout",), help="Redirect stdout to this file if run in daemon mode") +ARG_LOG_FILE = Arg(("-l", "--log-file"), help="Location of the log file if run in daemon mode") +ARG_YES = Arg( + ("-y", "--yes"), + help="Skip confirmation prompt and proceed with shutdown", + action="store_true", + default=False, +) + +EDGE_COMMANDS: list[ActionCommand] = [ + ActionCommand( + name="worker", + help="Start Airflow Edge Worker.", + func=lazy_load_command("airflow.providers.edge3.cli.edge_command.worker"), + args=( + ARG_CONCURRENCY, + ARG_QUEUES, + ARG_EDGE_HOSTNAME, + ARG_PID, + ARG_VERBOSE, + ARG_DAEMON, + ARG_STDOUT, + ARG_STDERR, + ARG_LOG_FILE, + ARG_UMASK, + ), + ), + ActionCommand( + name="status", + help="Check for Airflow Local Edge Worker status.", + func=lazy_load_command("airflow.providers.edge3.cli.edge_command.status"), + args=( + ARG_PID, + ARG_VERBOSE, + ), + ), + ActionCommand( + name="maintenance", + help="Set or Unset maintenance mode of local edge worker.", + func=lazy_load_command("airflow.providers.edge3.cli.edge_command.maintenance"), + args=( + ARG_MAINTENANCE, + ARG_MAINTENANCE_COMMENT, + ARG_WAIT_MAINT, + ARG_PID, + ARG_VERBOSE, + ), + ), + ActionCommand( + name="stop", + help="Stop a running local Airflow Edge Worker.", + func=lazy_load_command("airflow.providers.edge3.cli.edge_command.stop"), + args=( + ARG_WAIT_STOP, + ARG_PID, + ARG_VERBOSE, + ), + ), + ActionCommand( + name="list-workers", + help="Query the db to list all registered edge workers.", + func=lazy_load_command("airflow.providers.edge3.cli.edge_command.list_edge_workers"), + args=( + ARG_OUTPUT, + ARG_STATE, + ), + ), + ActionCommand( + name="remote-edge-worker-request-maintenance", + help="Put remote edge worker on maintenance.", + func=lazy_load_command("airflow.providers.edge3.cli.edge_command.put_remote_worker_on_maintenance"), + args=( + ARG_REQUIRED_EDGE_HOSTNAME, + ARG_REQUIRED_MAINTENANCE_COMMENT, + ), + ), + ActionCommand( + name="remote-edge-worker-exit-maintenance", + help="Remove remote edge worker from maintenance.", + func=lazy_load_command( + "airflow.providers.edge3.cli.edge_command.remove_remote_worker_from_maintenance" + ), + args=(ARG_REQUIRED_EDGE_HOSTNAME,), + ), + ActionCommand( + name="remote-edge-worker-update-maintenance-comment", + help="Update maintenance comments of the remote edge worker.", + func=lazy_load_command( + 
"airflow.providers.edge3.cli.edge_command.remote_worker_update_maintenance_comment" + ), + args=( + ARG_REQUIRED_EDGE_HOSTNAME, + ARG_REQUIRED_MAINTENANCE_COMMENT, + ), + ), + ActionCommand( + name="remove-remote-edge-worker", + help="Remove remote edge worker entry from db.", + func=lazy_load_command("airflow.providers.edge3.cli.edge_command.remove_remote_worker"), + args=(ARG_REQUIRED_EDGE_HOSTNAME,), + ), + ActionCommand( + name="shutdown-remote-edge-worker", + help="Initiate the shutdown of the remote edge worker.", + func=lazy_load_command("airflow.providers.edge3.cli.edge_command.remote_worker_request_shutdown"), + args=(ARG_REQUIRED_EDGE_HOSTNAME,), + ), + ActionCommand( + name="add-worker-queues", + help="Add queues to an edge worker.", + func=lazy_load_command("airflow.providers.edge3.cli.edge_command.add_worker_queues"), + args=( + ARG_REQUIRED_EDGE_HOSTNAME, + ARG_QUEUES_MANAGE, + ), + ), + ActionCommand( + name="remove-worker-queues", + help="Remove queues from an edge worker.", + func=lazy_load_command("airflow.providers.edge3.cli.edge_command.remove_worker_queues"), + args=( + ARG_REQUIRED_EDGE_HOSTNAME, + ARG_QUEUES_MANAGE, + ), + ), + ActionCommand( + name="shutdown-all-workers", + help="Request graceful shutdown of all edge workers.", + func=lazy_load_command("airflow.providers.edge3.cli.edge_command.shutdown_all_workers"), + args=(ARG_YES,), + ), +] + + +def get_edge_cli_commands() -> list[GroupCommand]: + return [ + GroupCommand( + name="edge", + help="Edge Worker components", + description=( + "Start and manage Edge Worker. Works only when using EdgeExecutor. For more information, " + "see https://airflow.apache.org/docs/apache-airflow-providers-edge3/stable/edge_executor.html" + ), + subcommands=EDGE_COMMANDS, + ), + ] + + +def get_parser() -> argparse.ArgumentParser: + """ + Generate documentation; used by Sphinx. 
+ + :meta private: + """ + from airflow.cli.cli_parser import AirflowHelpFormatter, DefaultHelpParser, _add_command + + parser = DefaultHelpParser(prog="airflow", formatter_class=AirflowHelpFormatter) + subparsers = parser.add_subparsers(dest="subcommand", metavar="GROUP_OR_COMMAND") + for group_command in get_edge_cli_commands(): + _add_command(subparsers, group_command) + return parser diff --git a/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py b/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py index 63296917683f4..d3a8067f4bdb9 100644 --- a/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py +++ b/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py @@ -30,7 +30,6 @@ import psutil from airflow import settings -from airflow.cli.cli_config import ARG_PID, ARG_VERBOSE, ActionCommand, Arg from airflow.cli.commands.daemon_utils import run_command_with_daemon_option from airflow.cli.simple_table import AirflowConsole from airflow.configuration import conf @@ -422,206 +421,3 @@ def remove_worker_queues(args) -> None: except TypeError as e: logger.error(str(e)) raise SystemExit - - -ARG_CONCURRENCY = Arg( - ("-c", "--concurrency"), - type=int, - help="The number of worker processes", - default=conf.getint("edge", "worker_concurrency", fallback=8), -) -ARG_QUEUES = Arg( - ("-q", "--queues"), - help="Comma delimited list of queues to serve, serve all queues if not provided.", -) -ARG_EDGE_HOSTNAME = Arg( - ("-H", "--edge-hostname"), - help="Set the hostname of worker if you have multiple workers on a single machine", -) -ARG_REQUIRED_EDGE_HOSTNAME = Arg( - ("-H", "--edge-hostname"), - help="Set the hostname of worker if you have multiple workers on a single machine", - required=True, -) -ARG_MAINTENANCE = Arg(("maintenance",), help="Desired maintenance state", choices=("on", "off")) -ARG_MAINTENANCE_COMMENT = Arg( - ("-c", "--comments"), - help="Maintenance comments to report reason. Required if maintenance is turned on.", -) -ARG_REQUIRED_MAINTENANCE_COMMENT = Arg( - ("-c", "--comments"), - help="Maintenance comments to report reason. Required if enabling maintenance", - required=True, -) -ARG_QUEUES_MANAGE = Arg( - ("-q", "--queues"), - help="Comma delimited list of queues to add or remove.", - required=True, -) -ARG_WAIT_MAINT = Arg( - ("-w", "--wait"), - default=False, - help="Wait until edge worker has reached desired state.", - action="store_true", -) -ARG_WAIT_STOP = Arg( - ("-w", "--wait"), - default=False, - help="Wait until edge worker is shut down.", - action="store_true", -) -ARG_OUTPUT = Arg( - ( - "-o", - "--output", - ), - help="Output format. 
Allowed values: json, yaml, plain, table (default: table)", - metavar="(table, json, yaml, plain)", - choices=("table", "json", "yaml", "plain"), - default="table", -) -ARG_STATE = Arg( - ( - "-s", - "--state", - ), - nargs="+", - help="State of the edge worker", -) - -ARG_DAEMON = Arg( - ("-D", "--daemon"), help="Daemonize instead of running in the foreground", action="store_true" -) -ARG_UMASK = Arg( - ("-u", "--umask"), - help="Set the umask of edge worker in daemon mode", -) -ARG_STDERR = Arg(("--stderr",), help="Redirect stderr to this file if run in daemon mode") -ARG_STDOUT = Arg(("--stdout",), help="Redirect stdout to this file if run in daemon mode") -ARG_LOG_FILE = Arg(("-l", "--log-file"), help="Location of the log file if run in daemon mode") -ARG_YES = Arg( - ("-y", "--yes"), - help="Skip confirmation prompt and proceed with shutdown", - action="store_true", - default=False, -) - -EDGE_COMMANDS: list[ActionCommand] = [ - ActionCommand( - name=worker.__name__, - help=worker.__doc__, - func=worker, - args=( - ARG_CONCURRENCY, - ARG_QUEUES, - ARG_EDGE_HOSTNAME, - ARG_PID, - ARG_VERBOSE, - ARG_DAEMON, - ARG_STDOUT, - ARG_STDERR, - ARG_LOG_FILE, - ARG_UMASK, - ), - ), - ActionCommand( - name=status.__name__, - help=status.__doc__, - func=status, - args=( - ARG_PID, - ARG_VERBOSE, - ), - ), - ActionCommand( - name=maintenance.__name__, - help=maintenance.__doc__, - func=maintenance, - args=( - ARG_MAINTENANCE, - ARG_MAINTENANCE_COMMENT, - ARG_WAIT_MAINT, - ARG_PID, - ARG_VERBOSE, - ), - ), - ActionCommand( - name=stop.__name__, - help=stop.__doc__, - func=stop, - args=( - ARG_WAIT_STOP, - ARG_PID, - ARG_VERBOSE, - ), - ), - ActionCommand( - name="list-workers", - help=list_edge_workers.__doc__, - func=list_edge_workers, - args=( - ARG_OUTPUT, - ARG_STATE, - ), - ), - ActionCommand( - name="remote-edge-worker-request-maintenance", - help=put_remote_worker_on_maintenance.__doc__, - func=put_remote_worker_on_maintenance, - args=( - ARG_REQUIRED_EDGE_HOSTNAME, - ARG_REQUIRED_MAINTENANCE_COMMENT, - ), - ), - ActionCommand( - name="remote-edge-worker-exit-maintenance", - help=remove_remote_worker_from_maintenance.__doc__, - func=remove_remote_worker_from_maintenance, - args=(ARG_REQUIRED_EDGE_HOSTNAME,), - ), - ActionCommand( - name="remote-edge-worker-update-maintenance-comment", - help=remote_worker_update_maintenance_comment.__doc__, - func=remote_worker_update_maintenance_comment, - args=( - ARG_REQUIRED_EDGE_HOSTNAME, - ARG_REQUIRED_MAINTENANCE_COMMENT, - ), - ), - ActionCommand( - name="remove-remote-edge-worker", - help=remove_remote_worker.__doc__, - func=remove_remote_worker, - args=(ARG_REQUIRED_EDGE_HOSTNAME,), - ), - ActionCommand( - name="shutdown-remote-edge-worker", - help=remote_worker_request_shutdown.__doc__, - func=remote_worker_request_shutdown, - args=(ARG_REQUIRED_EDGE_HOSTNAME,), - ), - ActionCommand( - name="add-worker-queues", - help=add_worker_queues.__doc__, - func=add_worker_queues, - args=( - ARG_REQUIRED_EDGE_HOSTNAME, - ARG_QUEUES_MANAGE, - ), - ), - ActionCommand( - name="remove-worker-queues", - help=remove_worker_queues.__doc__, - func=remove_worker_queues, - args=( - ARG_REQUIRED_EDGE_HOSTNAME, - ARG_QUEUES_MANAGE, - ), - ), - ActionCommand( - name="shutdown-all-workers", - help=shutdown_all_workers.__doc__, - func=shutdown_all_workers, - args=(ARG_YES,), - ), -] diff --git a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py index b82b5df02ed90..60392062469bc 
100644 --- a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py +++ b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py @@ -27,13 +27,11 @@ from sqlalchemy.exc import NoSuchTableError from sqlalchemy.orm import Session -from airflow.cli.cli_config import GroupCommand from airflow.configuration import conf from airflow.executors import workloads from airflow.executors.base_executor import BaseExecutor from airflow.models.taskinstance import TaskInstance from airflow.providers.common.compat.sdk import Stats, timezone -from airflow.providers.edge3.cli.edge_command import EDGE_COMMANDS from airflow.providers.edge3.models.edge_job import EdgeJobModel from airflow.providers.edge3.models.edge_logs import EdgeLogsModel from airflow.providers.edge3.models.edge_worker import EdgeWorkerModel, EdgeWorkerState, reset_metrics @@ -42,10 +40,9 @@ from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: - import argparse - from sqlalchemy.engine.base import Engine + from airflow.cli.cli_config import GroupCommand from airflow.models.taskinstancekey import TaskInstanceKey # TODO: Airflow 2 type hints; remove when Airflow 2 support is removed @@ -383,23 +380,6 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task @staticmethod def get_cli_commands() -> list[GroupCommand]: - return [ - GroupCommand( - name="edge", - help="Edge Worker components", - description=( - "Start and manage Edge Worker. Works only when using EdgeExecutor. For more information, " - "see https://airflow.apache.org/docs/apache-airflow-providers-edge3/stable/edge_executor.html" - ), - subcommands=EDGE_COMMANDS, - ), - ] - - -def _get_parser() -> argparse.ArgumentParser: - """ - Generate documentation; used by Sphinx. + from airflow.providers.edge3.cli.definition import get_edge_cli_commands - :meta private: - """ - return EdgeExecutor._get_parser() + return get_edge_cli_commands() diff --git a/providers/edge3/src/airflow/providers/edge3/get_provider_info.py b/providers/edge3/src/airflow/providers/edge3/get_provider_info.py index b8a3fb890d0b1..393b8cf9b0810 100644 --- a/providers/edge3/src/airflow/providers/edge3/get_provider_info.py +++ b/providers/edge3/src/airflow/providers/edge3/get_provider_info.py @@ -32,6 +32,7 @@ def get_provider_info(): "plugin-class": "airflow.providers.edge3.plugins.edge_executor_plugin.EdgeExecutorPlugin", } ], + "cli": ["airflow.providers.edge3.cli.definition.get_edge_cli_commands"], "executors": ["airflow.providers.edge3.executors.EdgeExecutor"], "config": { "edge": { diff --git a/providers/edge3/tests/unit/edge3/cli/test_definition.py b/providers/edge3/tests/unit/edge3/cli/test_definition.py new file mode 100644 index 0000000000000..a99bbc7565e7f --- /dev/null +++ b/providers/edge3/tests/unit/edge3/cli/test_definition.py @@ -0,0 +1,236 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import importlib +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.cli import cli_parser +from airflow.providers.edge3.cli.definition import EDGE_COMMANDS, get_edge_cli_commands + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS + + +class TestEdgeCliDefinition: + @pytest.fixture(autouse=True) + def setup_parser(self): + if AIRFLOW_V_3_2_PLUS: + importlib.reload(cli_parser) + cli_parser.get_parser.cache_clear() + self.arg_parser = cli_parser.get_parser() + else: + with patch( + "airflow.executors.executor_loader.ExecutorLoader.get_executor_names", + ) as mock_get_executor_names: + mock_get_executor_names.return_value = [ + MagicMock( + name="EdgeExecutor", module_path="airflow.providers.edge3.executors.EdgeExecutor" + ) + ] + importlib.reload(cli_parser) + cli_parser.get_parser.cache_clear() + self.arg_parser = cli_parser.get_parser() + + def test_edge_cli_commands_count(self): + """Test that get_edge_cli_commands returns exactly 1 GroupCommand.""" + commands = get_edge_cli_commands() + assert len(commands) == 1 + + def test_edge_commands_count(self): + """Test that EDGE_COMMANDS contains all 13 subcommands.""" + assert len(EDGE_COMMANDS) == 13 + + @pytest.mark.parametrize( + "command", + [ + "worker", + "status", + "maintenance", + "stop", + "list-workers", + "remote-edge-worker-request-maintenance", + "remote-edge-worker-exit-maintenance", + "remote-edge-worker-update-maintenance-comment", + "remove-remote-edge-worker", + "shutdown-remote-edge-worker", + "add-worker-queues", + "remove-worker-queues", + "shutdown-all-workers", + ], + ) + def test_edge_subcommands_defined(self, command): + """Test that all edge subcommands are properly defined.""" + params = ["edge", command, "--help"] + with pytest.raises(SystemExit) as exc_info: + self.arg_parser.parse_args(params) + # --help exits with code 0 + assert exc_info.value.code == 0 + + def test_worker_command_args(self): + """Test worker command with various arguments.""" + params = [ + "edge", + "worker", + "--queues", + "queue1,queue2", + "--concurrency", + "4", + "--edge-hostname", + "edge-worker-1", + ] + args = self.arg_parser.parse_args(params) + assert args.queues == "queue1,queue2" + assert args.concurrency == 4 + assert args.edge_hostname == "edge-worker-1" + + def test_status_command_args(self): + """Test status command with pid argument.""" + params = ["edge", "status", "--pid", "/path/to/pid"] + args = self.arg_parser.parse_args(params) + assert args.pid == "/path/to/pid" + + def test_maintenance_command_args_on(self): + """Test maintenance command to enable maintenance mode.""" + params = [ + "edge", + "maintenance", + "on", + "--comments", + "Scheduled maintenance", + "--wait", + ] + args = self.arg_parser.parse_args(params) + assert args.maintenance == "on" + assert args.comments == "Scheduled maintenance" + assert args.wait is True + + def test_maintenance_command_args_off(self): + """Test maintenance command to disable maintenance mode.""" + params = ["edge", "maintenance", "off"] + args = self.arg_parser.parse_args(params) + assert args.maintenance == "off" + + def test_stop_command_args(self): + """Test stop command with wait argument.""" + params = ["edge", "stop", "--wait", "--pid", "/path/to/pid"] + args = self.arg_parser.parse_args(params) + assert args.wait is True + assert args.pid == "/path/to/pid" + + def 
test_list_workers_command_args(self): + """Test list-workers command with output format and state filter.""" + params = ["edge", "list-workers", "--output", "json", "--state", "running", "maintenance"] + args = self.arg_parser.parse_args(params) + assert args.output == "json" + assert args.state == ["running", "maintenance"] + + def test_remote_edge_worker_request_maintenance_args(self): + """Test remote-edge-worker-request-maintenance command with required arguments.""" + params = [ + "edge", + "remote-edge-worker-request-maintenance", + "--edge-hostname", + "remote-worker-1", + "--comments", + "Emergency maintenance", + ] + args = self.arg_parser.parse_args(params) + assert args.edge_hostname == "remote-worker-1" + assert args.comments == "Emergency maintenance" + + def test_remote_edge_worker_exit_maintenance_args(self): + """Test remote-edge-worker-exit-maintenance command with required hostname.""" + params = [ + "edge", + "remote-edge-worker-exit-maintenance", + "--edge-hostname", + "remote-worker-1", + ] + args = self.arg_parser.parse_args(params) + assert args.edge_hostname == "remote-worker-1" + + def test_remote_edge_worker_update_maintenance_comment_args(self): + """Test remote-edge-worker-update-maintenance-comment command with required arguments.""" + params = [ + "edge", + "remote-edge-worker-update-maintenance-comment", + "--edge-hostname", + "remote-worker-1", + "--comments", + "Updated maintenance reason", + ] + args = self.arg_parser.parse_args(params) + assert args.edge_hostname == "remote-worker-1" + assert args.comments == "Updated maintenance reason" + + def test_remove_remote_edge_worker_args(self): + """Test remove-remote-edge-worker command with required hostname.""" + params = [ + "edge", + "remove-remote-edge-worker", + "--edge-hostname", + "remote-worker-1", + ] + args = self.arg_parser.parse_args(params) + assert args.edge_hostname == "remote-worker-1" + + def test_shutdown_remote_edge_worker_args(self): + """Test shutdown-remote-edge-worker command with required hostname.""" + params = [ + "edge", + "shutdown-remote-edge-worker", + "--edge-hostname", + "remote-worker-1", + ] + args = self.arg_parser.parse_args(params) + assert args.edge_hostname == "remote-worker-1" + + def test_add_worker_queues_args(self): + """Test add-worker-queues command with required arguments.""" + params = [ + "edge", + "add-worker-queues", + "--edge-hostname", + "remote-worker-1", + "--queues", + "queue3,queue4", + ] + args = self.arg_parser.parse_args(params) + assert args.edge_hostname == "remote-worker-1" + assert args.queues == "queue3,queue4" + + def test_remove_worker_queues_args(self): + """Test remove-worker-queues command with required arguments.""" + params = [ + "edge", + "remove-worker-queues", + "--edge-hostname", + "remote-worker-1", + "--queues", + "queue1", + ] + args = self.arg_parser.parse_args(params) + assert args.edge_hostname == "remote-worker-1" + assert args.queues == "queue1" + + def test_shutdown_all_workers_args(self): + """Test shutdown-all-workers command with yes flag.""" + params = ["edge", "shutdown-all-workers", "--yes"] + args = self.arg_parser.parse_args(params) + assert args.yes is True diff --git a/providers/edge3/tests/unit/edge3/cli/test_worker.py b/providers/edge3/tests/unit/edge3/cli/test_worker.py index b8d7435a9855f..2189a6979942a 100644 --- a/providers/edge3/tests/unit/edge3/cli/test_worker.py +++ b/providers/edge3/tests/unit/edge3/cli/test_worker.py @@ -16,7 +16,6 @@ # under the License. 
from __future__ import annotations -import argparse import contextlib import importlib import json @@ -31,7 +30,6 @@ from requests import HTTPError, Response from airflow.cli import cli_parser -from airflow.executors import executor_loader from airflow.providers.common.compat.sdk import timezone from airflow.providers.edge3.cli import edge_command from airflow.providers.edge3.cli.dataclasses import Job @@ -49,6 +47,7 @@ from airflow.utils.state import TaskInstanceState from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS pytest.importorskip("pydantic", minversion="2.0.0") @@ -87,16 +86,22 @@ def returncode(self): class TestEdgeWorker: - parser: argparse.ArgumentParser - - @classmethod - def setup_class(cls): - with conf_vars( - {("core", "executor"): "airflow.providers.edge3.executors.edge_executor.EdgeExecutor"} - ): - importlib.reload(executor_loader) + @pytest.fixture(autouse=True) + def setup_parser(self): + if AIRFLOW_V_3_2_PLUS: importlib.reload(cli_parser) - cls.parser = cli_parser.get_parser() + self.parser = cli_parser.get_parser() + else: + with patch( + "airflow.executors.executor_loader.ExecutorLoader.get_executor_names", + ) as mock_get_executor_names: + mock_get_executor_names.return_value = [ + MagicMock( + name="EdgeExecutor", module_path="airflow.providers.edge3.executors.EdgeExecutor" + ) + ] + importlib.reload(cli_parser) + self.parser = cli_parser.get_parser() @pytest.fixture def mock_joblist(self, tmp_path: Path) -> list[Job]: diff --git a/providers/fab/docs/cli-ref.rst b/providers/fab/docs/cli-ref.rst index e5e42b8425b0b..84712c03d800c 100644 --- a/providers/fab/docs/cli-ref.rst +++ b/providers/fab/docs/cli-ref.rst @@ -19,6 +19,6 @@ FAB CLI Commands ================ .. argparse:: - :module: airflow.providers.fab.auth_manager.fab_auth_manager + :module: airflow.providers.fab.cli.definition :func: get_parser :prog: airflow diff --git a/providers/fab/provider.yaml b/providers/fab/provider.yaml index bff1a9cea31f9..95dbced12960f 100644 --- a/providers/fab/provider.yaml +++ b/providers/fab/provider.yaml @@ -252,3 +252,6 @@ config: auth-managers: - airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager + +cli: + - airflow.providers.fab.cli.definition.get_fab_cli_commands diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py index 397f174fb6dbb..c5486bcdc19e1 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -17,13 +17,11 @@ # under the License. 
from __future__ import annotations -import argparse from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING, Any from urllib.parse import urljoin -import packaging.version from connexion import FlaskApi from fastapi import FastAPI from flask import Blueprint, current_app, g @@ -32,7 +30,6 @@ from sqlalchemy.orm import Session, joinedload from starlette.middleware.wsgi import WSGIMiddleware -from airflow import __version__ as airflow_version from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager @@ -52,21 +49,10 @@ VariableDetails, ) from airflow.api_fastapi.common.types import ExtraMenuItem, MenuItem -from airflow.cli.cli_config import ( - DefaultHelpParser, - GroupCommand, -) from airflow.configuration import conf from airflow.exceptions import AirflowConfigException from airflow.models import Connection, DagModel, Pool, Variable from airflow.providers.common.compat.sdk import AirflowException -from airflow.providers.fab.auth_manager.cli_commands.definition import ( - DB_COMMANDS, - PERMISSIONS_CLEANUP_COMMAND, - ROLES_COMMANDS, - SYNC_PERM_COMMAND, - USERS_COMMANDS, -) from airflow.providers.fab.auth_manager.models import Permission, Role, User from airflow.providers.fab.auth_manager.models.anonymous_user import AnonymousUser from airflow.providers.fab.version_compat import AIRFLOW_V_3_1_PLUS @@ -202,26 +188,9 @@ def apiserver_endpoint(self) -> str: @staticmethod def get_cli_commands() -> list[CLICommand]: """Vends CLI commands to be included in Airflow CLI.""" - commands: list[CLICommand] = [ - GroupCommand( - name="users", - help="Manage users", - subcommands=USERS_COMMANDS, - ), - GroupCommand( - name="roles", - help="Manage roles", - subcommands=ROLES_COMMANDS, - ), - SYNC_PERM_COMMAND, # not in a command group - PERMISSIONS_CLEANUP_COMMAND, # single command for permissions cleanup - ] - # If Airflow version is 3.0.0 or higher, add the fab-db command group - if packaging.version.parse( - packaging.version.parse(airflow_version).base_version - ) >= packaging.version.parse("3.0.0"): - commands.append(GroupCommand(name="fab-db", help="Manage FAB", subcommands=DB_COMMANDS)) - return commands + from airflow.providers.fab.cli.definition import get_fab_cli_commands + + return get_fab_cli_commands() def get_fastapi_app(self) -> FastAPI | None: """Get the FastAPI app.""" @@ -760,14 +729,3 @@ def _sync_appbuilder_roles(self): # delete the old ones. if conf.getboolean("fab", "UPDATE_FAB_PERMS"): self.security_manager.sync_roles() - - -def get_parser() -> argparse.ArgumentParser: - """Generate documentation; used by Sphinx argparse.""" - from airflow.cli.cli_parser import AirflowHelpFormatter, _add_command - - parser = DefaultHelpParser(prog="airflow", formatter_class=AirflowHelpFormatter) - subparsers = parser.add_subparsers(dest="subcommand", metavar="GROUP_OR_COMMAND") - for group_command in FabAuthManager.get_cli_commands(): - _add_command(subparsers, group_command) - return parser diff --git a/providers/fab/src/airflow/providers/fab/cli/__init__.py b/providers/fab/src/airflow/providers/fab/cli/__init__.py new file mode 100644 index 0000000000000..03cb33c14c40e --- /dev/null +++ b/providers/fab/src/airflow/providers/fab/cli/__init__.py @@ -0,0 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/cli_commands/definition.py b/providers/fab/src/airflow/providers/fab/cli/definition.py similarity index 89% rename from providers/fab/src/airflow/providers/fab/auth_manager/cli_commands/definition.py rename to providers/fab/src/airflow/providers/fab/cli/definition.py index eab5e5eedb449..2ac4fa8e0d665 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/cli_commands/definition.py +++ b/providers/fab/src/airflow/providers/fab/cli/definition.py @@ -17,6 +17,7 @@ from __future__ import annotations import textwrap +from typing import TYPE_CHECKING from airflow.cli.cli_config import ( ARG_DB_FROM_REVISION, @@ -35,6 +36,9 @@ lazy_load_command, ) +if TYPE_CHECKING: + import argparse + ############ # # ARGS # # ############ @@ -336,3 +340,47 @@ args=(ARG_YES, ARG_DB_SKIP_INIT, ARG_VERBOSE), ), ) + + +def get_fab_cli_commands(): + """Return CLI commands for FAB auth manager.""" + import packaging.version + + from airflow import __version__ as airflow_version + from airflow.cli.cli_config import GroupCommand + + commands = [ + GroupCommand( + name="users", + help="Manage users", + subcommands=USERS_COMMANDS, + ), + GroupCommand( + name="roles", + help="Manage roles", + subcommands=ROLES_COMMANDS, + ), + SYNC_PERM_COMMAND, # not in a command group + PERMISSIONS_CLEANUP_COMMAND, # single command for permissions cleanup + ] + # If Airflow version is 3.0.0 or higher, add the fab-db command group + if packaging.version.parse( + packaging.version.parse(airflow_version).base_version + ) >= packaging.version.parse("3.0.0"): + commands.append(GroupCommand(name="fab-db", help="Manage FAB", subcommands=DB_COMMANDS)) + return commands + + +def get_parser() -> argparse.ArgumentParser: + """ + Generate documentation; used by Sphinx argparse. 
+ + :meta private: + """ + from airflow.cli.cli_parser import AirflowHelpFormatter, DefaultHelpParser, _add_command + + parser = DefaultHelpParser(prog="airflow", formatter_class=AirflowHelpFormatter) + subparsers = parser.add_subparsers(dest="subcommand", metavar="GROUP_OR_COMMAND") + for group_command in get_fab_cli_commands(): + _add_command(subparsers, group_command) + return parser diff --git a/providers/fab/src/airflow/providers/fab/get_provider_info.py b/providers/fab/src/airflow/providers/fab/get_provider_info.py index 068111260c6ec..3ca0fa9a66202 100644 --- a/providers/fab/src/airflow/providers/fab/get_provider_info.py +++ b/providers/fab/src/airflow/providers/fab/get_provider_info.py @@ -181,4 +181,5 @@ def get_provider_info(): } }, "auth-managers": ["airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager"], + "cli": ["airflow.providers.fab.cli.definition.get_fab_cli_commands"], } diff --git a/providers/fab/tests/unit/fab/auth_manager/cli/__init__.py b/providers/fab/tests/unit/fab/auth_manager/cli/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/fab/tests/unit/fab/auth_manager/cli/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
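Taken together, the FAB hunks above establish the provider-level CLI pattern: a lightweight ``cli/definition.py`` module that builds the command tree, exposed through a ``cli`` entry in ``provider.yaml`` and ``get_provider_info()``. A minimal sketch of that pattern for a hypothetical ``acme`` provider (the provider name, command name, and handler path are illustrative, not part of this patch):

# providers/acme/src/airflow/providers/acme/cli/definition.py (hypothetical)
from __future__ import annotations

from airflow.cli.cli_config import ActionCommand, Arg, GroupCommand, lazy_load_command

ARG_NAME = Arg(("-n", "--name"), help="Name of the acme resource")

ACME_COMMANDS = (
    ActionCommand(
        name="sync",
        help="Sync acme resources",
        # The implementation module is only imported when the command actually runs.
        func=lazy_load_command("airflow.providers.acme.cli.commands.sync"),
        args=(ARG_NAME,),
    ),
)


def get_acme_cli_commands():
    """Factory referenced by the provider's ``cli`` section, mirroring get_fab_cli_commands above."""
    return [GroupCommand(name="acme", help="Manage acme resources", subcommands=ACME_COMMANDS)]

The matching registration would be a ``cli: [airflow.providers.acme.cli.definition.get_acme_cli_commands]`` entry in ``provider.yaml``, which gives the parser everything it needs without importing the provider's executor or auth manager code.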
diff --git a/providers/fab/tests/unit/fab/auth_manager/cli_commands/test_db_command.py b/providers/fab/tests/unit/fab/auth_manager/cli_commands/test_db_command.py index a17953b110d47..c5ab264384f29 100644 --- a/providers/fab/tests/unit/fab/auth_manager/cli_commands/test_db_command.py +++ b/providers/fab/tests/unit/fab/auth_manager/cli_commands/test_db_command.py @@ -24,6 +24,7 @@ from airflow.cli import cli_parser from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS pytestmark = [pytest.mark.db_test] try: @@ -31,21 +32,22 @@ from airflow.providers.fab.auth_manager.models.db import FABDBManager class TestFABCLiDB: - @classmethod - def setup_class(cls): - with conf_vars( - { - ( - "core", - "auth_manager", - ): "airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager", - } - ): - # Reload the module to use FAB auth manager + @pytest.fixture(autouse=True) + def setup_parser(self): + if AIRFLOW_V_3_2_PLUS: reload(cli_parser) - # Clearing the cache before calling it - cli_parser.get_parser.cache_clear() - cls.parser = cli_parser.get_parser() + self.parser = cli_parser.get_parser() + else: + with conf_vars( + { + ( + "core", + "auth_manager", + ): "airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager", + } + ): + reload(cli_parser) + self.parser = cli_parser.get_parser() @mock.patch.object(FABDBManager, "resetdb") def test_cli_resetdb(self, mock_resetdb): diff --git a/providers/fab/tests/unit/fab/cli/__init__.py b/providers/fab/tests/unit/fab/cli/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/fab/tests/unit/fab/cli/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/fab/tests/unit/fab/auth_manager/cli_commands/test_definition.py b/providers/fab/tests/unit/fab/cli/test_definition.py similarity index 94% rename from providers/fab/tests/unit/fab/auth_manager/cli_commands/test_definition.py rename to providers/fab/tests/unit/fab/cli/test_definition.py index b34f92875ba12..d6904683eab2b 100644 --- a/providers/fab/tests/unit/fab/auth_manager/cli_commands/test_definition.py +++ b/providers/fab/tests/unit/fab/cli/test_definition.py @@ -16,7 +16,7 @@ # under the License. 
from __future__ import annotations -from airflow.providers.fab.auth_manager.cli_commands.definition import ( +from airflow.providers.fab.cli.definition import ( ROLES_COMMANDS, SYNC_PERM_COMMAND, USERS_COMMANDS, diff --git a/providers/keycloak/docs/auth-manager/manage/permissions.rst b/providers/keycloak/docs/auth-manager/manage/permissions.rst index ba18719e215c0..7e72cfddf0991 100644 --- a/providers/keycloak/docs/auth-manager/manage/permissions.rst +++ b/providers/keycloak/docs/auth-manager/manage/permissions.rst @@ -45,7 +45,7 @@ They also take the following optional parameters: * ``--dry-run``: If set, the command will check the connection to Keycloak and print the actions that would be performed, without actually executing them. -Please check the `Keycloak auth manager CLI `_ documentation for more information about accepted parameters. +Please check the `Keycloak auth manager CLI `_ documentation for more information about accepted parameters. One-go creation of permissions ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/providers/keycloak/docs/cli-refs.rst b/providers/keycloak/docs/cli-ref.rst similarity index 87% rename from providers/keycloak/docs/cli-refs.rst rename to providers/keycloak/docs/cli-ref.rst index eef6311221f82..903d5e4cdddaa 100644 --- a/providers/keycloak/docs/cli-refs.rst +++ b/providers/keycloak/docs/cli-ref.rst @@ -15,12 +15,12 @@ specific language governing permissions and limitations under the License. -Command Line Interface (CLI) -============================ +Keycloak Command Line Interface +=============================== The provider CLI is integrated into the Apache Airflow ``airflow`` command. .. argparse:: - :module: airflow.providers.keycloak.auth_manager.keycloak_auth_manager + :module: airflow.providers.keycloak.cli.definition :func: get_parser :prog: airflow diff --git a/providers/keycloak/docs/index.rst b/providers/keycloak/docs/index.rst index f01c9b8d5ea58..9ab43c12fac87 100644 --- a/providers/keycloak/docs/index.rst +++ b/providers/keycloak/docs/index.rst @@ -44,7 +44,7 @@ Python API <_api/airflow/providers/keycloak/index> Configuration Keycloak auth manager token API - CLI + CLI .. toctree:: :hidden: diff --git a/providers/keycloak/provider.yaml b/providers/keycloak/provider.yaml index 1d3a940d08f25..3c965a8ac02cb 100644 --- a/providers/keycloak/provider.yaml +++ b/providers/keycloak/provider.yaml @@ -35,6 +35,9 @@ versions: auth-managers: - airflow.providers.keycloak.auth_manager.keycloak_auth_manager.KeycloakAuthManager +cli: + - airflow.providers.keycloak.cli.definition.get_keycloak_cli_commands + config: keycloak_auth_manager: description: This section contains settings for Keycloak auth manager integration. diff --git a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/cli/__init__.py b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/cli/__init__.py index 13a83393a9124..21d298ede6ed3 100644 --- a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/cli/__init__.py +++ b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/cli/__init__.py @@ -14,3 +14,4 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations diff --git a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py index 9ad07ffbabb33..8a1e4989a36db 100644 --- a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py +++ b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import argparse import json import logging import time @@ -39,9 +38,8 @@ from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod as ExtendedResourceMethod from airflow.api_fastapi.common.types import MenuItem -from airflow.cli.cli_config import CLICommand, DefaultHelpParser, GroupCommand +from airflow.cli.cli_config import CLICommand from airflow.providers.common.compat.sdk import AirflowException, conf -from airflow.providers.keycloak.auth_manager.cli.definition import KEYCLOAK_AUTH_MANAGER_COMMANDS from airflow.providers.keycloak.auth_manager.constants import ( CONF_CLIENT_ID_KEY, CONF_CLIENT_SECRET_KEY, @@ -69,23 +67,13 @@ PoolDetails, VariableDetails, ) + from airflow.cli.cli_config import CLICommand log = logging.getLogger(__name__) RESOURCE_ID_ATTRIBUTE_NAME = "resource_id" -def get_parser() -> argparse.ArgumentParser: - """Generate documentation; used by Sphinx argparse.""" - from airflow.cli.cli_parser import AirflowHelpFormatter, _add_command - - parser = DefaultHelpParser(prog="airflow", formatter_class=AirflowHelpFormatter) - subparsers = parser.add_subparsers(dest="subcommand", metavar="GROUP_OR_COMMAND") - for group_command in KeycloakAuthManager.get_cli_commands(): - _add_command(subparsers, group_command) - return parser - - class KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]): """ Keycloak auth manager. @@ -307,13 +295,9 @@ def get_fastapi_app(self) -> FastAPI | None: @staticmethod def get_cli_commands() -> list[CLICommand]: """Vends CLI commands to be included in Airflow CLI.""" - return [ - GroupCommand( - name="keycloak-auth-manager", - help="Manage resources used by Keycloak auth manager", - subcommands=KEYCLOAK_AUTH_MANAGER_COMMANDS, - ), - ] + from airflow.providers.keycloak.cli.definition import get_keycloak_cli_commands + + return get_keycloak_cli_commands() @staticmethod def get_keycloak_client() -> KeycloakOpenID: diff --git a/providers/keycloak/src/airflow/providers/keycloak/cli/__init__.py b/providers/keycloak/src/airflow/providers/keycloak/cli/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/keycloak/src/airflow/providers/keycloak/cli/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
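A detail worth noting in the definition-module rename below: the ``Password`` argparse action defers ``import getpass`` into its ``__call__`` body, so merely importing the definition module stays cheap and only an actual prompt pays the import cost. A self-contained sketch of that pattern (the ``demo`` parser wiring is illustrative, not from this patch):

import argparse


class Password(argparse.Action):
    """Prompt for a password when the flag is passed without an inline value."""

    def __call__(self, parser, namespace, values, option_string=None):
        if values is None:
            # Deferred import: never executed unless a prompt is actually needed.
            import getpass

            values = getpass.getpass(prompt="Password: ")
        setattr(namespace, self.dest, values)


parser = argparse.ArgumentParser(prog="demo")
parser.add_argument("--password", nargs="?", action=Password)
args = parser.parse_args(["--password", "s3cret"])  # inline value given, so no prompt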
diff --git a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/cli/definition.py b/providers/keycloak/src/airflow/providers/keycloak/cli/definition.py similarity index 78% rename from providers/keycloak/src/airflow/providers/keycloak/auth_manager/cli/definition.py rename to providers/keycloak/src/airflow/providers/keycloak/cli/definition.py index f736a830b2594..218f272c07915 100644 --- a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/cli/definition.py +++ b/providers/keycloak/src/airflow/providers/keycloak/cli/definition.py @@ -18,7 +18,6 @@ from __future__ import annotations import argparse -import getpass from airflow.cli.cli_config import ( ActionCommand, @@ -32,6 +31,8 @@ class Password(argparse.Action): def __call__(self, parser, namespace, values, option_string=None): if values is None: + import getpass + values = getpass.getpass(prompt="Password: ") setattr(namespace, self.dest, values) @@ -97,3 +98,31 @@ def __call__(self, parser, namespace, values, option_string=None): args=(ARG_USERNAME, ARG_PASSWORD, ARG_USER_REALM, ARG_CLIENT_ID, ARG_DRY_RUN), ), ) + + +def get_keycloak_cli_commands(): + """Return CLI commands for Keycloak auth manager.""" + from airflow.cli.cli_config import GroupCommand + + return [ + GroupCommand( + name="keycloak-auth-manager", + help="Manage resources used by Keycloak auth manager", + subcommands=KEYCLOAK_AUTH_MANAGER_COMMANDS, + ), + ] + + +def get_parser() -> argparse.ArgumentParser: + """ + Generate documentation; used by Sphinx argparse. + + :meta private: + """ + from airflow.cli.cli_parser import AirflowHelpFormatter, DefaultHelpParser, _add_command + + parser = DefaultHelpParser(prog="airflow", formatter_class=AirflowHelpFormatter) + subparsers = parser.add_subparsers(dest="subcommand", metavar="GROUP_OR_COMMAND") + for group_command in get_keycloak_cli_commands(): + _add_command(subparsers, group_command) + return parser diff --git a/providers/keycloak/src/airflow/providers/keycloak/get_provider_info.py b/providers/keycloak/src/airflow/providers/keycloak/get_provider_info.py index eb9e941731d1f..9913b03ba7a32 100644 --- a/providers/keycloak/src/airflow/providers/keycloak/get_provider_info.py +++ b/providers/keycloak/src/airflow/providers/keycloak/get_provider_info.py @@ -29,6 +29,7 @@ def get_provider_info(): "auth-managers": [ "airflow.providers.keycloak.auth_manager.keycloak_auth_manager.KeycloakAuthManager" ], + "cli": ["airflow.providers.keycloak.cli.definition.get_keycloak_cli_commands"], "config": { "keycloak_auth_manager": { "description": "This section contains settings for Keycloak auth manager integration.", diff --git a/providers/keycloak/tests/unit/keycloak/auth_manager/cli/test_commands.py b/providers/keycloak/tests/unit/keycloak/auth_manager/cli/test_commands.py index f5c244e539462..3532da1937f86 100644 --- a/providers/keycloak/tests/unit/keycloak/auth_manager/cli/test_commands.py +++ b/providers/keycloak/tests/unit/keycloak/auth_manager/cli/test_commands.py @@ -34,22 +34,27 @@ from airflow.providers.keycloak.auth_manager.resources import KeycloakResource from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS @pytest.mark.db_test class TestCommands: - @classmethod - def setup_class(cls): - with conf_vars( - { - ( - "core", - "auth_manager", - ): "airflow.providers.keycloak.auth_manager.keycloak_auth_manager.KeycloakAuthManager", - } - ): + @pytest.fixture(autouse=True) + def setup_parser(self): + if AIRFLOW_V_3_2_PLUS: 
importlib.reload(cli_parser) - cls.arg_parser = cli_parser.get_parser() + self.arg_parser = cli_parser.get_parser() + else: + with conf_vars( + { + ( + "core", + "auth_manager", + ): "airflow.providers.keycloak.auth_manager.keycloak_auth_manager.KeycloakAuthManager", + } + ): + importlib.reload(cli_parser) + self.arg_parser = cli_parser.get_parser() @patch("airflow.providers.keycloak.auth_manager.cli.commands._get_client") def test_create_scopes(self, mock_get_client): diff --git a/providers/keycloak/tests/unit/keycloak/cli/__init__.py b/providers/keycloak/tests/unit/keycloak/cli/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/keycloak/tests/unit/keycloak/cli/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/keycloak/tests/unit/keycloak/auth_manager/cli/test_definition.py b/providers/keycloak/tests/unit/keycloak/cli/test_definition.py similarity index 82% rename from providers/keycloak/tests/unit/keycloak/auth_manager/cli/test_definition.py rename to providers/keycloak/tests/unit/keycloak/cli/test_definition.py index 398766f313118..d10a727f69460 100644 --- a/providers/keycloak/tests/unit/keycloak/auth_manager/cli/test_definition.py +++ b/providers/keycloak/tests/unit/keycloak/cli/test_definition.py @@ -23,24 +23,29 @@ import pytest from airflow.cli import cli_parser -from airflow.providers.keycloak.auth_manager.cli.definition import KEYCLOAK_AUTH_MANAGER_COMMANDS, Password +from airflow.providers.keycloak.cli.definition import KEYCLOAK_AUTH_MANAGER_COMMANDS, Password from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS class TestKeycloakCliDefinition: - @classmethod - def setup_class(cls): - with conf_vars( - { - ( - "core", - "auth_manager", - ): "airflow.providers.keycloak.auth_manager.keycloak_auth_manager.KeycloakAuthManager", - } - ): + @pytest.fixture(autouse=True) + def setup_parser(self): + if AIRFLOW_V_3_2_PLUS: importlib.reload(cli_parser) - cls.arg_parser = cli_parser.get_parser() + self.arg_parser = cli_parser.get_parser() + else: + with conf_vars( + { + ( + "core", + "auth_manager", + ): "airflow.providers.keycloak.auth_manager.keycloak_auth_manager.KeycloakAuthManager", + } + ): + importlib.reload(cli_parser) + self.arg_parser = cli_parser.get_parser() def test_keycloak_auth_manager_cli_commands(self): assert len(KEYCLOAK_AUTH_MANAGER_COMMANDS) == 4 diff --git a/scripts/ci/prek/check_cli_definition_imports.py b/scripts/ci/prek/check_cli_definition_imports.py new file mode 100755 index 0000000000000..ba14b8e2f717b --- /dev/null +++ b/scripts/ci/prek/check_cli_definition_imports.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python +# +# Licensed to the Apache Software Foundation (ASF) 
under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# /// script +# requires-python = ">=3.10,<3.11" +# dependencies = [ +# "rich>=13.6.0", +# ] +# /// +""" +Check that CLI definition files only import from allowed modules. + +CLI definition files (matching pattern */cli/definition.py) may only import from +'airflow.configuration', 'airflow.cli.cli_config', their own provider's 'version_compat' +module, and a set of standard-library modules. This avoids circular imports, keeps the +modules cheap to import, and ensures clean separation of concerns. +""" + +from __future__ import annotations + +import argparse +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.resolve())) # make sure common_prek_utils is imported +from common_prek_utils import console, get_imports_from_file + +# Allowed modules that can be imported in CLI definition files +ALLOWED_MODULES = { + "airflow.configuration", + "airflow.cli.cli_config", +} + +# Standard library and __future__ modules are also allowed +STDLIB_PREFIXES = ( + "argparse", + "getpass", + "textwrap", + "typing", + "collections", + "functools", + "itertools", + "pathlib", + "os", + "sys", + "re", + "json", + "dataclasses", + "enum", +) + + +def get_provider_path_from_file(file_path: Path) -> str | None: + """ + Extract the provider path from a CLI definition file. + + For example: + - providers/celery/src/airflow/providers/celery/cli/definition.py -> celery + - providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/definition.py -> cncf.kubernetes + """ + path_str = file_path.as_posix() + + # Find the substring between "airflow/providers/" and "/cli/definition" + start_marker = "airflow/providers/" + end_marker = "/cli/definition" + + start_idx = path_str.find(start_marker) + end_idx = path_str.find(end_marker) + + if start_idx == -1 or end_idx == -1 or start_idx >= end_idx: + return None + + # Extract the provider path and replace '/' with '.' 
+ provider_path = path_str[start_idx + len(start_marker) : end_idx] + return provider_path.replace("/", ".") + + +def is_allowed_import(import_name: str, file_path: Path) -> bool: + """Check if an import is allowed in CLI definition files.""" + # Check if it's one of the allowed Airflow modules + for allowed_module in ALLOWED_MODULES: + if import_name == allowed_module or import_name.startswith(f"{allowed_module}."): + return True + + # Check if it's a standard library module + for prefix in STDLIB_PREFIXES: + if import_name == prefix or import_name.startswith(f"{prefix}."): + return True + + # Allow imports from the provider's own version_compat module + provider_path = get_provider_path_from_file(file_path) + if provider_path: + version_compat_module = f"airflow.providers.{provider_path}.version_compat" + if import_name == version_compat_module or import_name.startswith(f"{version_compat_module}."): + return True + + return False + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Check that CLI definition files only import from allowed modules." + ) + parser.add_argument("files", nargs="*", type=Path, help="Python source files to check.") + return parser.parse_args() + + +def main() -> int: + args = parse_args() + + if not args.files: + console.print("[yellow]No files provided.[/]") + return 0 + + errors: list[str] = [] + + # Filter files to only check */cli/definition.py files + cli_definition_files = [ + path + for path in args.files + if path.name == "definition.py" and len(path.parts) >= 2 and path.parts[-2] == "cli" + ] + + if not cli_definition_files: + console.print("[yellow]No CLI definition files found to check.[/]") + return 0 + + console.print(f"[blue]Checking {len(cli_definition_files)} CLI definition file(s)...[/]") + console.print(cli_definition_files) + + for path in cli_definition_files: + try: + imports = get_imports_from_file(path, only_top_level=True) + except Exception as e: + console.print(f"[red]Failed to parse {path}: {e}[/]") + return 2 + + forbidden_imports = [] + for imp in imports: + if not is_allowed_import(imp, path): + forbidden_imports.append(imp) + + if forbidden_imports: + errors.append(f"\n[red]{path}:[/]") + for imp in forbidden_imports: + errors.append(f" - {imp}") + + if errors: + console.print("\n[red] Some CLI definition files contain forbidden imports![/]\n") + console.print( + f"[yellow]CLI definition files (*/cli/definition.py) should only import from:[/]\n" + " - airflow.configuration\n" + " - airflow.cli.cli_config\n" + " - Their own provider's version_compat module\n" + f" - Standard library modules ({', '.join(STDLIB_PREFIXES)})\n" + ) + console.print("[red]Found forbidden imports in:[/]") + for error in errors: + console.print(error) + console.print( + "\n[yellow]This restriction exists to:[/]\n" + " - Keep CLI definitions lightweight and declarative to avoid slowdowns\n" + " - Ensure clean separation between CLI structure and implementation\n" + ) + return 1 + + console.print("[green] All CLI definition files import only from allowed modules![/]") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/in_container/benchmark_cli_latency.py b/scripts/in_container/benchmark_cli_latency.py new file mode 100755 index 0000000000000..da2754396545e --- /dev/null +++ b/scripts/in_container/benchmark_cli_latency.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Benchmark script to measure CLI latency for different Auth Manager and Executor combinations. + +This script: +1. Discovers all available Auth Managers and Executors from providers +2. Tests each combination by running 'airflow --help' +3. Measures response time for each combination +4. Generates a markdown report with results +""" + +from __future__ import annotations + +import os +import subprocess +import sys +import time +from pathlib import Path + +# Add airflow to path (the repo root is two levels up from scripts/in_container) +AIRFLOW_SOURCES_DIR = Path(__file__).resolve().parents[2] / "airflow-core" / "src" +sys.path.insert(0, str(AIRFLOW_SOURCES_DIR)) + + +def get_available_auth_managers() -> list[str]: + """Get all available auth manager class names from providers.""" + from airflow.providers_manager import ProvidersManager + + pm = ProvidersManager() + return pm.auth_managers + + +def get_available_executors() -> list[str]: + """Get all available executor class names from providers.""" + from airflow.providers_manager import ProvidersManager + + pm = ProvidersManager() + # Get executors from providers + executor_names = pm.executor_class_names + + # Add core executors + core_executors = [ + "airflow.executors.local_executor.LocalExecutor", + "airflow.executors.sequential_executor.SequentialExecutor", + ] + + all_executors = list(set(core_executors + executor_names)) + return sorted(all_executors) + + +def measure_cli_latency( + auth_manager: str | None, executor: str | None, runs: int = 3 +) -> tuple[float, float, bool]: + """ + Measure the latency of 'airflow --help' command. 
+ + Args: + auth_manager: Auth manager class name (None for default) + executor: Executor class name (None for default) + runs: Number of runs to average + + Returns: + Tuple of (average_time, min_time, success) + """ + env = os.environ.copy() + + if auth_manager: + env["AIRFLOW__CORE__AUTH_MANAGER"] = auth_manager + if executor: + env["AIRFLOW__CORE__EXECUTOR"] = executor + + times = [] + success = True + + for _ in range(runs): + start = time.time() + try: + result = subprocess.run( + ["airflow", "--help"], + env=env, + capture_output=True, + timeout=30, + check=False, + ) + elapsed = time.time() - start + + # Check if command succeeded + if result.returncode != 0: + success = False + break + + times.append(elapsed) + except (subprocess.TimeoutExpired, Exception) as e: + print(f"Error running command: {e}", file=sys.stderr) + success = False + break + + if not times: + return 0.0, 0.0, False + + avg_time = sum(times) / len(times) + min_time = min(times) + + return avg_time, min_time, success + + +def format_class_name(class_name: str | None) -> str: + """Format class name for display (show only last part).""" + if class_name is None: + return "Default" + parts = class_name.split(".") + if len(parts) > 1: + return parts[-1] + return class_name + + +def generate_markdown_report(results: list[dict]) -> str: + """Generate markdown formatted report.""" + lines = [ + "# Airflow CLI Latency Benchmark", + "", + "Benchmark results for `airflow --help` command with different Auth Manager and Executor combinations.", + "", + f"Total combinations tested: {len(results)}", + "", + "## Results Table", + "", + "| Auth Manager | Executor | Avg Time (s) | Min Time (s) | Status |", + "|--------------|----------|--------------|--------------|--------|", + ] + + for result in results: + auth_display = format_class_name(result["auth_manager"]) + executor_display = format_class_name(result["executor"]) + avg_time = f"{result['avg_time']:.3f}" if result["success"] else "N/A" + min_time = f"{result['min_time']:.3f}" if result["success"] else "N/A" + status = "✅" if result["success"] else "❌" + + lines.append(f"| {auth_display} | {executor_display} | {avg_time} | {min_time} | {status} |") + + lines.extend( + [ + "", + "## Summary Statistics", + "", + ] + ) + + successful_results = [r for r in results if r["success"]] + if successful_results: + avg_times = [r["avg_time"] for r in successful_results] + lines.extend( + [ + f"- **Successful combinations**: {len(successful_results)}/{len(results)}", + f"- **Overall average time**: {sum(avg_times) / len(avg_times):.3f}s", + f"- **Fastest time**: {min(avg_times):.3f}s", + f"- **Slowest time**: {max(avg_times):.3f}s", + ] + ) + else: + lines.append("- No successful combinations") + + lines.extend( + [ + "", + "---", + "", + "*Note: Each combination was run 3 times and averaged.*", + ] + ) + + return "\n".join(lines) + + +def main(): + """Main function to run the benchmark.""" + print("=" * 80) + print("Airflow CLI Latency Benchmark") + print("=" * 80) + print() + + print("Discovering available Auth Managers and Executors...") + + try: + auth_managers = get_available_auth_managers() + executors = get_available_executors() + except Exception as e: + print(f"Error discovering providers: {e}", file=sys.stderr) + return 1 + + print(f"Found {len(auth_managers)} Auth Managers") + print(f"Found {len(executors)} Executors") + print() + + # Add None to test default configuration + auth_managers_to_test = [None] + auth_managers + executors_to_test = [None] + executors + + 
total_combinations = len(auth_managers_to_test) * len(executors_to_test) + print(f"Testing {total_combinations} combinations...") + print() + + results = [] + count = 0 + + for auth_manager in auth_managers_to_test: + for executor in executors_to_test: + count += 1 + auth_display = format_class_name(auth_manager) + executor_display = format_class_name(executor) + + print( + f"[{count}/{total_combinations}] Testing: {auth_display} + {executor_display}...", + end=" ", + flush=True, + ) + + avg_time, min_time, success = measure_cli_latency(auth_manager, executor) + + results.append( + { + "auth_manager": auth_manager, + "executor": executor, + "avg_time": avg_time, + "min_time": min_time, + "success": success, + } + ) + + if success: + print(f"✅ {avg_time:.3f}s (avg) / {min_time:.3f}s (min)") + else: + print("❌ Failed") + + print() + print("=" * 80) + print("Generating report...") + print("=" * 80) + print() + + report = generate_markdown_report(results) + print(report) + + # Optionally save to file + output_file = Path("cli_latency_benchmark.md") + output_file.write_text(report) + print() + print(f"Report saved to: {output_file.absolute()}") + + return 0 + + +if __name__ == "__main__": + sys.exit(main())
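Run directly, the benchmark script prints the markdown report and saves it to ``cli_latency_benchmark.md``. For a quick spot check of a single combination rather than the full matrix, the helpers can be driven from a small script; a usage sketch, assuming the file is importable from the working directory and the ``airflow`` entrypoint is on ``PATH``:

from benchmark_cli_latency import format_class_name, measure_cli_latency

FAB = "airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager"

# Three runs of `airflow --help` with the FAB auth manager and the default executor.
avg_s, min_s, ok = measure_cli_latency(auth_manager=FAB, executor=None)
print(f"{format_class_name(FAB)}: {avg_s:.3f}s avg / {min_s:.3f}s min" if ok else "failed")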