diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index 8021a82ca9ff1..610c6d6cac990 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -27,7 +27,6 @@ from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.resource_details import DagDetails from airflow.configuration import conf -from airflow.exceptions import AirflowException from airflow.models import DagModel from airflow.typing_compat import Literal from airflow.utils.jwt_signer import JWTSigner, get_signing_key @@ -38,8 +37,6 @@ from collections.abc import Container, Sequence from fastapi import FastAPI - from flask import Blueprint - from flask_appbuilder.menu import MenuItem from sqlalchemy.orm import Session from airflow.auth.managers.models.batch_apis import ( @@ -81,22 +78,6 @@ def init(self) -> None: By default, do nothing. """ - def get_user_name(self) -> str: - """Return the username associated to the user in session.""" - user = self.get_user() - if not user: - self.log.error("Calling 'get_user_name()' but the user is not signed in.") - raise AirflowException("The user must be signed in.") - return user.get_name() - - def get_user_display_name(self) -> str: - """Return the user's display name associated to the user in session.""" - return self.get_user_name() - - @abstractmethod - def get_user(self) -> T | None: - """Return the user associated to the user in session.""" - @abstractmethod def deserialize_user(self, token: dict[str, Any]) -> T: """Create a user object from dict.""" @@ -122,36 +103,10 @@ def get_jwt_token( expiration_time_in_seconds=expiration_time_in_seconds ).generate_signed_token(self.serialize_user(user)) - def get_user_id(self) -> str | None: - """Return the user ID associated to the user in session.""" - user = self.get_user() - if not user: - self.log.error("Calling 'get_user_id()' but the user is not signed in.") - raise AirflowException("The user must be signed in.") - if user_id := user.get_id(): - return str(user_id) - return None - - @abstractmethod - def is_logged_in(self) -> bool: - """Return whether the user is logged in.""" - @abstractmethod def get_url_login(self, **kwargs) -> str: """Return the login page url.""" - @abstractmethod - def get_url_logout(self) -> str: - """Return the logout page url.""" - - def get_url_user_profile(self) -> str | None: - """ - Return the url to a page displaying info about the current user. - - By default, return None. - """ - return None - @abstractmethod def is_authorized_configuration( self, @@ -282,14 +237,6 @@ def is_authorized_custom_view(self, *, method: ResourceMethod | str, resource_na :param user: the user to performing the action """ - @abstractmethod - def filter_permitted_menu_items(self, menu_items: list[MenuItem]) -> list[MenuItem]: - """ - Filter menu items based on user permissions. - - :param menu_items: list of all menu items - """ - def batch_is_authorized_connection( self, requests: Sequence[IsAuthorizedConnectionRequest], @@ -444,11 +391,6 @@ def get_cli_commands() -> list[CLICommand]: """ return [] - def get_api_endpoints(self) -> None | Blueprint: - """Return API endpoint(s) definition for the auth manager.""" - # TODO: Remove this method when legacy Airflow 2 UI is gone - return None - def get_fastapi_app(self) -> FastAPI | None: """ Specify a sub FastAPI application specific to the auth manager. @@ -457,9 +399,6 @@ def get_fastapi_app(self) -> FastAPI | None: """ return None - def register_views(self) -> None: - """Register views specific to the auth manager.""" - @staticmethod def _get_token_signer( expiration_time_in_seconds: int = conf.getint("api", "auth_jwt_expiration_time"), diff --git a/airflow/auth/managers/simple/simple_auth_manager.py b/airflow/auth/managers/simple/simple_auth_manager.py index 7887786633c30..7f37efae7ea43 100644 --- a/airflow/auth/managers/simple/simple_auth_manager.py +++ b/airflow/auth/managers/simple/simple_auth_manager.py @@ -26,7 +26,6 @@ from typing import TYPE_CHECKING, Any from fastapi import FastAPI -from flask import session from starlette.requests import Request from starlette.responses import HTMLResponse from starlette.staticfiles import StaticFiles @@ -39,8 +38,6 @@ from airflow.settings import AIRFLOW_PATH if TYPE_CHECKING: - from flask_appbuilder.menu import MenuItem - from airflow.auth.managers.base_auth_manager import ResourceMethod from airflow.auth.managers.models.resource_details import ( AccessView, @@ -127,27 +124,10 @@ def init(self) -> None: with open(self.get_generated_password_file(), "w") as file: file.write(json.dumps(passwords)) - def is_logged_in(self) -> bool: - # Remove this method when legacy UI is removed - return "user" in session or conf.getboolean("core", "simple_auth_manager_all_admins") - def get_url_login(self, **kwargs) -> str: """Return the login page url.""" return "/auth/webapp/login" - def get_url_logout(self) -> str: - # Remove this method when legacy UI is removed - raise NotImplementedError() - - def get_user(self) -> SimpleAuthManagerUser | None: - # Remove this method when legacy UI is removed - if not self.is_logged_in(): - return None - if conf.getboolean("core", "simple_auth_manager_all_admins"): - return SimpleAuthManagerUser(username="anonymous", role="admin") - else: - return session["user"] - def deserialize_user(self, token: dict[str, Any]) -> SimpleAuthManagerUser: return SimpleAuthManagerUser(username=token["username"], role=token["role"]) @@ -232,9 +212,6 @@ def is_authorized_custom_view( ): return self._is_authorized(method="GET", allow_role=SimpleAuthManagerRole.VIEWER, user=user) - def filter_permitted_menu_items(self, menu_items: list[MenuItem]) -> list[MenuItem]: - return menu_items - def get_fastapi_app(self) -> FastAPI | None: """ Specify a sub FastAPI application specific to the auth manager. diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py index bd5e8572c6ffc..533e95f5fd000 100644 --- a/dev/breeze/tests/test_selective_checks.py +++ b/dev/breeze/tests/test_selective_checks.py @@ -153,7 +153,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str): pytest.param( ("airflow/api/file.py",), { - "selected-providers-list-as-string": "amazon common.compat databricks edge fab", + "selected-providers-list-as-string": "common.compat databricks edge fab", "all-python-versions": "['3.9']", "all-python-versions-list-as-string": "3.9", "python-versions": "['3.9']", @@ -162,14 +162,13 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str): "prod-image-build": "false", "needs-helm-tests": "false", "run-tests": "true", - "run-amazon-tests": "true", "docs-build": "true", "skip-pre-commits": "check-provider-yaml-valid,identity,lint-helm-chart,mypy-airflow,mypy-dev," "mypy-docs,mypy-providers,mypy-task-sdk,ts-compile-format-lint-ui", "upgrade-to-newer-dependencies": "false", "core-test-types-list-as-string": "API Always", - "providers-test-types-list-as-string": "Providers[amazon] Providers[common.compat,databricks,edge,fab]", - "individual-providers-test-types-list-as-string": "Providers[amazon] Providers[common.compat] Providers[databricks] Providers[edge] Providers[fab]", + "providers-test-types-list-as-string": "Providers[common.compat,databricks,edge,fab]", + "individual-providers-test-types-list-as-string": "Providers[common.compat] Providers[databricks] Providers[edge] Providers[fab]", "testable-core-integrations": "['celery', 'kerberos']", "testable-providers-integrations": "['cassandra', 'drill', 'kafka', 'mongo', 'pinot', 'qdrant', 'redis', 'trino', 'ydb']", "needs-mypy": "true", @@ -722,7 +721,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str): ("providers/amazon/src/airflow/providers/amazon/__init__.py",), { "selected-providers-list-as-string": "amazon apache.hive cncf.kubernetes " - "common.compat common.messaging common.sql exasol fab ftp google http imap microsoft.azure " + "common.compat common.messaging common.sql exasol ftp google http imap microsoft.azure " "mongo mysql openlineage postgres salesforce ssh teradata", "all-python-versions": "['3.9']", "all-python-versions-list-as-string": "3.9", @@ -739,7 +738,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str): "upgrade-to-newer-dependencies": "false", "run-amazon-tests": "true", "core-test-types-list-as-string": "Always", - "providers-test-types-list-as-string": "Providers[amazon] Providers[apache.hive,cncf.kubernetes,common.compat,common.messaging,common.sql,exasol,fab,ftp,http,imap,microsoft.azure,mongo,mysql,openlineage,postgres,salesforce,ssh,teradata] Providers[google]", + "providers-test-types-list-as-string": "Providers[amazon] Providers[apache.hive,cncf.kubernetes,common.compat,common.messaging,common.sql,exasol,ftp,http,imap,microsoft.azure,mongo,mysql,openlineage,postgres,salesforce,ssh,teradata] Providers[google]", "needs-mypy": "true", "mypy-checks": "['mypy-providers']", }, @@ -774,7 +773,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str): ("providers/amazon/src/airflow/providers/amazon/file.py",), { "selected-providers-list-as-string": "amazon apache.hive cncf.kubernetes " - "common.compat common.messaging common.sql exasol fab ftp google http imap microsoft.azure " + "common.compat common.messaging common.sql exasol ftp google http imap microsoft.azure " "mongo mysql openlineage postgres salesforce ssh teradata", "all-python-versions": "['3.9']", "all-python-versions-list-as-string": "3.9", @@ -791,7 +790,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str): "run-kubernetes-tests": "false", "upgrade-to-newer-dependencies": "false", "core-test-types-list-as-string": "Always", - "providers-test-types-list-as-string": "Providers[amazon] Providers[apache.hive,cncf.kubernetes,common.compat,common.messaging,common.sql,exasol,fab,ftp,http,imap,microsoft.azure,mongo,mysql,openlineage,postgres,salesforce,ssh,teradata] Providers[google]", + "providers-test-types-list-as-string": "Providers[amazon] Providers[apache.hive,cncf.kubernetes,common.compat,common.messaging,common.sql,exasol,ftp,http,imap,microsoft.azure,mongo,mysql,openlineage,postgres,salesforce,ssh,teradata] Providers[google]", "needs-mypy": "true", "mypy-checks": "['mypy-providers']", }, diff --git a/docs/apache-airflow/core-concepts/auth-manager/index.rst b/docs/apache-airflow/core-concepts/auth-manager/index.rst index b61b44ae39ec4..dbc0f279ea6b4 100644 --- a/docs/apache-airflow/core-concepts/auth-manager/index.rst +++ b/docs/apache-airflow/core-concepts/auth-manager/index.rst @@ -96,10 +96,8 @@ Some reasons you may want to write a custom auth manager include: Authentication related BaseAuthManager methods ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -* ``is_logged_in``: Return whether the user is signed-in. * ``get_user``: Return the signed-in user. * ``get_url_login``: Return the URL the user is redirected to for signing in. -* ``get_url_logout``: Return the URL the user is redirected to for signing out. Authorization related BaseAuthManager methods ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -142,7 +140,6 @@ The following methods aren't required to override to have a functional Airflow a * ``batch_is_authorized_pool``: Batch version of ``is_authorized_pool``. If not overridden, it will call ``is_authorized_pool`` for every single item. * ``batch_is_authorized_variable``: Batch version of ``is_authorized_variable``. If not overridden, it will call ``is_authorized_variable`` for every single item. * ``get_permitted_dag_ids``: Return the list of DAG IDs the user has access to. If not overridden, it will call ``is_authorized_dag`` for every single DAG available in the environment. -* ``filter_permitted_menu_items``: Return the menu items the user has access to. If not overridden, it will call ``has_access`` in :class:`~airflow.www.security_manager.AirflowSecurityManagerV2` for every single menu item. CLI ^^^ @@ -176,10 +173,13 @@ Auth managers may vend CLI commands which will be included in the ``airflow`` co .. note:: When creating a new auth manager, or updating any existing auth manager, be sure to not import or execute any expensive operations/code at the module level. Auth manager classes are imported in several places and if they are slow to import this will negatively impact the performance of your Airflow environment, especially for CLI commands. -Rest API -^^^^^^^^ +Extending API server application +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Auth managers may vend Rest API endpoints which will be included in the :doc:`/stable-rest-api-ref` by implementing the ``get_api_endpoints`` method. The endpoints can be used to manage resources such as users, groups, roles (if any) handled by your auth manager. Endpoints are only vended for the currently configured auth manager. +Auth managers have the option to extend the Airflow API server. Doing so, allow, for instance, to vend additional public API endpoints. +To extend the API server application, you need to implement the ``get_fastapi_app`` method. +Such additional endpoints can be used to manage resources such as users, groups, roles (if any) handled by your auth manager. +Endpoints defined by ``get_fastapi_app`` are mounted in ``/auth``. Next Steps ^^^^^^^^^^ diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index a5f403ba44424..e1fc18617cb49 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -59,7 +59,6 @@ "common.compat", "common.sql", "exasol", - "fab", "ftp", "google", "http", diff --git a/newsfragments/aip-79.significant.rst b/newsfragments/aip-79.significant.rst index 97d26bee0f10c..497c2054b1435 100644 --- a/newsfragments/aip-79.significant.rst +++ b/newsfragments/aip-79.significant.rst @@ -12,7 +12,17 @@ As part of this change the following breaking changes have occurred: - The property ``security_manager`` has been removed from the interface - - The method ``filter_permitted_menu_items`` is now abstract and must be implemented + - All these methods have been removed from the interface: + + - ``filter_permitted_menu_items`` + - ``get_user_name`` + - ``get_user_display_name`` + - ``get_user`` + - ``get_user_id`` + - ``is_logged_in`` + - ``get_url_logout`` + - ``get_api_endpoints`` + - ``register_views`` - All the following method signatures changed to make the parameter ``user`` required (it was optional) diff --git a/providers/amazon/README.rst b/providers/amazon/README.rst index f053738fa492a..0fec71dbae179 100644 --- a/providers/amazon/README.rst +++ b/providers/amazon/README.rst @@ -90,7 +90,6 @@ Dependent package `apache-airflow-providers-common-compat `_ ``common.compat`` `apache-airflow-providers-common-sql `_ ``common.sql`` `apache-airflow-providers-exasol `_ ``exasol`` -`apache-airflow-providers-fab `_ ``fab`` `apache-airflow-providers-ftp `_ ``ftp`` `apache-airflow-providers-google `_ ``google`` `apache-airflow-providers-http `_ ``http`` diff --git a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py index 455d03fba4725..dfcb0e1c46878 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py @@ -23,7 +23,6 @@ from typing import TYPE_CHECKING, Any, cast from fastapi import FastAPI -from flask import session from airflow.auth.managers.base_auth_manager import BaseAuthManager from airflow.auth.managers.models.resource_details import ( @@ -49,8 +48,6 @@ from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: - from flask_appbuilder.menu import MenuItem - from airflow.auth.managers.base_auth_manager import ResourceMethod from airflow.auth.managers.models.batch_apis import ( IsAuthorizedConnectionRequest, @@ -86,12 +83,6 @@ def avp_facade(self): def apiserver_endpoint(self) -> str: return conf.get("api", "base_url") - def get_user(self) -> AwsAuthManagerUser | None: - return session["aws_user"] if self.is_logged_in() else None - - def is_logged_in(self) -> bool: - return "aws_user" in session - def deserialize_user(self, token: dict[str, Any]) -> AwsAuthManagerUser: return AwsAuthManagerUser(**token) @@ -331,57 +322,9 @@ def _has_access_to_dag(request: IsAuthorizedRequest): ) } - def filter_permitted_menu_items(self, menu_items: list[MenuItem]) -> list[MenuItem]: - """ - Filter menu items based on user permissions. - - :param menu_items: list of all menu items - """ - user = self.get_user() - if not user: - return [] - - requests: dict[str, IsAuthorizedRequest] = {} - for menu_item in menu_items: - if menu_item.childs: - for child in menu_item.childs: - requests[child.name] = self._get_menu_item_request(child.name) - else: - requests[menu_item.name] = self._get_menu_item_request(menu_item.name) - - batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results( - requests=list(requests.values()), user=user - ) - - def _has_access_to_menu_item(request: IsAuthorizedRequest): - result = self.avp_facade.get_batch_is_authorized_single_result( - batch_is_authorized_results=batch_is_authorized_results, request=request, user=user - ) - return result["decision"] == "ALLOW" - - accessible_items = [] - for menu_item in menu_items: - if menu_item.childs: - accessible_children = [] - for child in menu_item.childs: - if _has_access_to_menu_item(requests[child.name]): - accessible_children.append(child) - menu_item.childs = accessible_children - - # Display the menu if the user has access to at least one sub item - if len(accessible_children) > 0: - accessible_items.append(menu_item) - elif _has_access_to_menu_item(requests[menu_item.name]): - accessible_items.append(menu_item) - - return accessible_items - def get_url_login(self, **kwargs) -> str: return f"{self.apiserver_endpoint}/auth/login" - def get_url_logout(self) -> str: - raise NotImplementedError() - @staticmethod def get_cli_commands() -> list[CLICommand]: """Vends CLI commands to be included in Airflow CLI.""" @@ -408,14 +351,6 @@ def get_fastapi_app(self) -> FastAPI | None: return app - @staticmethod - def _get_menu_item_request(resource_name: str) -> IsAuthorizedRequest: - return { - "method": "MENU", - "entity_type": AvpEntities.MENU, - "entity_id": resource_name, - } - def _check_avp_schema_version(self): if not self.avp_facade.is_policy_store_schema_up_to_date(): self.log.warning( diff --git a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/security_manager/__init__.py b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/security_manager/__init__.py deleted file mode 100644 index 13a83393a9124..0000000000000 --- a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/security_manager/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py deleted file mode 100644 index 1e4421e784904..0000000000000 --- a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py +++ /dev/null @@ -1,44 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from airflow.exceptions import AirflowOptionalProviderFeatureException -from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS - -if AIRFLOW_V_3_0_PLUS: - try: - from airflow.providers.fab.www.security_manager import AirflowSecurityManagerV2 - except ImportError: - raise AirflowOptionalProviderFeatureException( - "Failed to import AirflowSecurityManagerV2 from the FAB provider. The AWS auth manager requires the FAB provider." - ) -else: - from airflow.www.security_manager import AirflowSecurityManagerV2 - - -class AwsSecurityManagerOverride(AirflowSecurityManagerV2): - """ - The security manager override specific to AWS auth manager. - - This class is only used in Airflow 2. This can be safely be removed when min Airflow version >= 3 - """ - - def register_views(self): - """Register views specific to AWS auth manager.""" - from airflow.providers.amazon.aws.auth_manager.views.auth import AwsAuthManagerAuthenticationViews - - self.appbuilder.add_view_no_menu(AwsAuthManagerAuthenticationViews()) diff --git a/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py b/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py index 44be563aa7635..a1f213c2b7355 100644 --- a/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py +++ b/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py @@ -20,7 +20,6 @@ from unittest.mock import ANY, Mock, patch import pytest -from flask_appbuilder.menu import MenuItem from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS @@ -39,22 +38,14 @@ from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities from airflow.providers.amazon.aws.auth_manager.aws_auth_manager import AwsAuthManager from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser -from airflow.security.permissions import ( - RESOURCE_AUDIT_LOG, - RESOURCE_CLUSTER_ACTIVITY, - RESOURCE_CONNECTION, - RESOURCE_VARIABLE, -) from tests_common.test_utils.config import conf_vars if TYPE_CHECKING: from airflow.auth.managers.base_auth_manager import ResourceMethod from airflow.auth.managers.models.resource_details import AssetDetails - from airflow.security.permissions import RESOURCE_ASSET else: from airflow.providers.common.compat.assets import AssetDetails - from airflow.providers.common.compat.security.permissions import RESOURCE_ASSET mock = Mock() @@ -100,13 +91,6 @@ class TestAwsAuthManager: def test_avp_facade(self, auth_manager): assert hasattr(auth_manager, "avp_facade") - @patch.object(AwsAuthManager, "is_logged_in") - def test_get_user_return_none_when_not_logged_in(self, mock_is_logged_in, auth_manager): - mock_is_logged_in.return_value = False - result = auth_manager.get_user() - - assert result is None - @pytest.mark.parametrize( "details, user, expected_user, expected_entity_id", [ @@ -460,127 +444,6 @@ def test_batch_is_authorized_variable( ) assert result - @patch.object(AwsAuthManager, "get_user") - def test_filter_permitted_menu_items(self, mock_get_user, auth_manager, test_user): - batch_is_authorized_output = [ - { - "request": { - "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, - "action": {"actionType": "Airflow::Action", "actionId": "Menu.MENU"}, - "resource": {"entityType": "Airflow::Menu", "entityId": "Connections"}, - }, - "decision": "DENY", - }, - { - "request": { - "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, - "action": {"actionType": "Airflow::Action", "actionId": "Menu.MENU"}, - "resource": {"entityType": "Airflow::Menu", "entityId": "Variables"}, - }, - "decision": "ALLOW", - }, - { - "request": { - "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, - "action": {"actionType": "Airflow::Action", "actionId": "Menu.MENU"}, - "resource": {"entityType": "Airflow::Menu", "entityId": RESOURCE_ASSET}, - }, - "decision": "DENY", - }, - { - "request": { - "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, - "action": {"actionType": "Airflow::Action", "actionId": "Menu.MENU"}, - "resource": {"entityType": "Airflow::Menu", "entityId": "Cluster Activity"}, - }, - "decision": "DENY", - }, - { - "request": { - "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, - "action": {"actionType": "Airflow::Action", "actionId": "Menu.MENU"}, - "resource": {"entityType": "Airflow::Menu", "entityId": "Audit Logs"}, - }, - "decision": "ALLOW", - }, - { - "request": { - "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, - "action": {"actionType": "Airflow::Action", "actionId": "Menu.MENU"}, - "resource": {"entityType": "Airflow::Menu", "entityId": "CustomPage"}, - }, - "decision": "ALLOW", - }, - ] - auth_manager.avp_facade.get_batch_is_authorized_results = Mock( - return_value=batch_is_authorized_output - ) - - mock_get_user.return_value = test_user - - result = auth_manager.filter_permitted_menu_items( - [ - MenuItem("Category1", childs=[MenuItem(RESOURCE_CONNECTION), MenuItem(RESOURCE_VARIABLE)]), - MenuItem("Category2", childs=[MenuItem(RESOURCE_ASSET)]), - MenuItem(RESOURCE_CLUSTER_ACTIVITY), - MenuItem(RESOURCE_AUDIT_LOG), - MenuItem("CustomPage"), - ] - ) - - """ - return { - "method": "MENU", - "entity_type": AvpEntities.MENU, - "entity_id": resource_name, - } - """ - - auth_manager.avp_facade.get_batch_is_authorized_results.assert_called_once_with( - requests=[ - { - "method": "MENU", - "entity_type": AvpEntities.MENU, - "entity_id": "Connections", - }, - { - "method": "MENU", - "entity_type": AvpEntities.MENU, - "entity_id": "Variables", - }, - { - "method": "MENU", - "entity_type": AvpEntities.MENU, - "entity_id": RESOURCE_ASSET, - }, - {"method": "MENU", "entity_type": AvpEntities.MENU, "entity_id": "Cluster Activity"}, - {"method": "MENU", "entity_type": AvpEntities.MENU, "entity_id": "Audit Logs"}, - { - "method": "MENU", - "entity_type": AvpEntities.MENU, - "entity_id": "CustomPage", - }, - ], - user=test_user, - ) - assert len(result) == 3 - assert result[0].name == "Category1" - assert len(result[0].childs) == 1 - assert result[0].childs[0].name == RESOURCE_VARIABLE - assert result[1].name == RESOURCE_AUDIT_LOG - assert result[2].name == "CustomPage" - - @patch.object(AwsAuthManager, "get_user") - def test_filter_permitted_menu_items_logged_out(self, mock_get_user, auth_manager): - mock_get_user.return_value = None - result = auth_manager.filter_permitted_menu_items( - [ - MenuItem(RESOURCE_AUDIT_LOG), - ] - ) - - assert result == [] - @pytest.mark.parametrize( "methods, user", [ diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py index 660ea82e5a8ca..3181fd8bdef4e 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -27,7 +27,6 @@ from connexion import FlaskApi from fastapi import FastAPI from flask import Blueprint, g, url_for -from flask_appbuilder.menu import MenuItem from packaging.version import Version from sqlalchemy import select from sqlalchemy.orm import Session, joinedload @@ -64,7 +63,6 @@ from airflow.providers.fab.www.extensions.init_views import _CustomErrorRequestBodyValidator, _LazyResolver from airflow.providers.fab.www.security import permissions from airflow.providers.fab.www.security.permissions import ( - ACTION_CAN_ACCESS_MENU, RESOURCE_AUDIT_LOG, RESOURCE_CLUSTER_ACTIVITY, RESOURCE_CONFIG, @@ -214,21 +212,13 @@ def get_api_endpoints(self) -> None | Blueprint: return FlaskApi( specification=specification, resolver=_LazyResolver(), - # TODO: change to "/fab/v1" when legacy UI is gone - base_path="/auth/fab/v1", + base_path="/fab/v1", options={"swagger_ui": SWAGGER_ENABLED, "swagger_path": SWAGGER_BUNDLE.__fspath__()}, strict_validation=True, validate_responses=True, validator_map={"body": _CustomErrorRequestBodyValidator}, ).blueprint - def get_user_display_name(self) -> str: - """Return the user's display name associated to the user in session.""" - user = self.get_user() - first_name = user.first_name.strip() if isinstance(user.first_name, str) else "" - last_name = user.last_name.strip() if isinstance(user.last_name, str) else "" - return f"{first_name} {last_name}".strip() - def get_user(self) -> User: """ Return the user associated to the user in session. @@ -406,32 +396,6 @@ def get_permitted_dag_ids( resources.add(resource) return set(session.scalars(select(DagModel.dag_id).where(DagModel.dag_id.in_(resources)))) - def filter_permitted_menu_items(self, menu_items: list[MenuItem]) -> list[MenuItem]: - """ - Filter menu items based on user permissions. - - :param menu_items: list of all menu items - """ - items = filter( - lambda item: self.security_manager.has_access(ACTION_CAN_ACCESS_MENU, item.name), menu_items - ) - accessible_items = [] - for menu_item in items: - menu_item_copy = MenuItem( - **{ - **menu_item.__dict__, - "childs": [], - } - ) - if menu_item.childs: - accessible_children = [] - for child in menu_item.childs: - if self.security_manager.has_access(ACTION_CAN_ACCESS_MENU, child.name): - accessible_children.append(child) - menu_item_copy.childs = accessible_children - accessible_items.append(menu_item_copy) - return accessible_items - @cached_property def security_manager(self) -> FabAirflowSecurityManagerOverride: """Return the security manager specific to FAB.""" @@ -470,14 +434,6 @@ def get_url_logout(self): raise AirflowException("`auth_view` not defined in the security manager.") return url_for(f"{self.security_manager.auth_view.endpoint}.logout") - def get_url_user_profile(self) -> str | None: - """Return the url to a page displaying info about the current user.""" - if not self.security_manager.user_view or ( - self.appbuilder and self.appbuilder.get_app.config.get("AUTH_ROLE_PUBLIC", None) - ): - return None - return url_for(f"{self.security_manager.user_view.endpoint}.userinfo") - def register_views(self) -> None: self.security_manager.register_views() diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py index f5045e7c2b294..adbdfe14397d9 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py @@ -113,6 +113,7 @@ AirflowDatabaseSessionInterface, AirflowDatabaseSessionInterface as FabAirflowDatabaseSessionInterface, ) +from airflow.providers.fab.www.utils import get_fab_auth_manager if TYPE_CHECKING: from airflow.providers.fab.www.security.permissions import RESOURCE_ASSET @@ -2461,7 +2462,7 @@ def auth_user_remote_user(self, username): return user def get_user_menu_access(self, menu_names: list[str] | None = None) -> set[str]: - if get_auth_manager().is_logged_in(): + if get_fab_auth_manager().is_logged_in(): return self._get_user_permission_resources(g.user, "menu_access", resource_names=menu_names) elif current_user_jwt: return self._get_user_permission_resources( diff --git a/providers/fab/src/airflow/providers/fab/www/auth.py b/providers/fab/src/airflow/providers/fab/www/auth.py index 40c5dc51f0c8b..1e45a03fc23c9 100644 --- a/providers/fab/src/airflow/providers/fab/www/auth.py +++ b/providers/fab/src/airflow/providers/fab/www/auth.py @@ -40,6 +40,7 @@ VariableDetails, ) from airflow.configuration import conf +from airflow.providers.fab.www.utils import get_fab_auth_manager from airflow.utils.net import get_hostname if TYPE_CHECKING: @@ -138,19 +139,19 @@ def _has_access(*, is_authorized: bool, func: Callable, args, kwargs): """ if is_authorized: return func(*args, **kwargs) - elif get_auth_manager().is_logged_in() and not get_auth_manager().is_authorized_view( + elif get_fab_auth_manager().is_logged_in() and not get_auth_manager().is_authorized_view( access_view=AccessView.WEBSITE, - user=get_auth_manager().get_user(), + user=get_fab_auth_manager().get_user(), ): return ( render_template( "airflow/no_roles_permissions.html", hostname=get_hostname() if conf.getboolean("webserver", "EXPOSE_HOSTNAME") else "", - logout_url=get_auth_manager().get_url_logout(), + logout_url=get_fab_auth_manager().get_url_logout(), ), 403, ) - elif not get_auth_manager().is_logged_in(): + elif not get_fab_auth_manager().is_logged_in(): return redirect(get_auth_manager().get_url_login(next_url=request.url)) else: access_denied = get_access_denied_message() @@ -161,7 +162,7 @@ def _has_access(*, is_authorized: bool, func: Callable, args, kwargs): def has_access_configuration(method: ResourceMethod) -> Callable[[T], T]: return _has_access_no_details( lambda: get_auth_manager().is_authorized_configuration( - method=method, user=get_auth_manager().get_user() + method=method, user=get_fab_auth_manager().get_user() ) ) @@ -276,7 +277,7 @@ def decorated(*args, **kwargs): def has_access_asset(method: ResourceMethod) -> Callable[[T], T]: """Check current user's permissions against required permissions for assets.""" return _has_access_no_details( - lambda: get_auth_manager().is_authorized_asset(method=method, user=get_auth_manager().get_user()) + lambda: get_auth_manager().is_authorized_asset(method=method, user=get_fab_auth_manager().get_user()) ) @@ -344,6 +345,6 @@ def has_access_view(access_view: AccessView = AccessView.WEBSITE) -> Callable[[T """Check current user's permissions to access the website.""" return _has_access_no_details( lambda: get_auth_manager().is_authorized_view( - access_view=access_view, user=get_auth_manager().get_user() + access_view=access_view, user=get_fab_auth_manager().get_user() ) ) diff --git a/providers/fab/src/airflow/providers/fab/www/extensions/init_appbuilder.py b/providers/fab/src/airflow/providers/fab/www/extensions/init_appbuilder.py index 912e1533ed6fe..bb4e338f76cbf 100644 --- a/providers/fab/src/airflow/providers/fab/www/extensions/init_appbuilder.py +++ b/providers/fab/src/airflow/providers/fab/www/extensions/init_appbuilder.py @@ -306,7 +306,9 @@ def _add_admin_views(self): self.indexview = self._check_and_init(self.indexview) self.add_view_no_menu(self.indexview) self.add_view_no_menu(UtilView()) - get_auth_manager().register_views() + auth_manager = get_auth_manager() + if hasattr(auth_manager, "register_views"): + auth_manager.register_views() def _add_addon_views(self): """Register declared addons.""" diff --git a/providers/fab/src/airflow/providers/fab/www/utils.py b/providers/fab/src/airflow/providers/fab/www/utils.py index 6ddf6265788a9..b14da6448642e 100644 --- a/providers/fab/src/airflow/providers/fab/www/utils.py +++ b/providers/fab/src/airflow/providers/fab/www/utils.py @@ -27,11 +27,25 @@ from sqlalchemy import types from sqlalchemy.ext.associationproxy import AssociationProxy +from airflow.api_fastapi.app import get_auth_manager from airflow.utils import timezone if TYPE_CHECKING: from sqlalchemy.orm.session import Session + from airflow.providers.fab.auth_manager.fab_auth_manager import FabAuthManager + + +def get_fab_auth_manager() -> FabAuthManager: + from airflow.providers.fab.auth_manager.fab_auth_manager import FabAuthManager + + auth_manager = get_auth_manager() + if not isinstance(auth_manager, FabAuthManager): + raise RuntimeError( + "This functionality is only available with if FabAuthManager is configured as auth manager in the environment." + ) + return auth_manager + class UtcAwareFilterMixin: """Mixin for filter for UTC time.""" diff --git a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_auth.py b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_auth.py index 4814941ba154d..1f71295a5b548 100644 --- a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_auth.py +++ b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_auth.py @@ -71,7 +71,7 @@ def test_success(self): clear_db_pools() with self.app.test_client() as test_client: - response = test_client.get("/auth/fab/v1/users", headers={"Authorization": token}) + response = test_client.get("/fab/v1/users", headers={"Authorization": token}) assert current_user.email == "test@fab.org" assert response.status_code == 200 diff --git a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py index 915fb36e28e78..54f7b9fe94ecf 100644 --- a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py +++ b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py @@ -80,14 +80,12 @@ def teardown_method(self): class TestGetRoleEndpoint(TestRoleEndpoint): def test_should_response_200(self): - response = self.client.get("/auth/fab/v1/roles/Admin", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/fab/v1/roles/Admin", environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 200 assert response.json["name"] == "Admin" def test_should_respond_404(self): - response = self.client.get( - "/auth/fab/v1/roles/invalid-role", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/fab/v1/roles/invalid-role", environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 404 assert response.json == { "detail": "Role with name 'invalid-role' was not found", @@ -97,12 +95,12 @@ def test_should_respond_404(self): } def test_should_raises_401_unauthenticated(self): - response = self.client.get("/auth/fab/v1/roles/Admin") + response = self.client.get("/fab/v1/roles/Admin") assert_401(response) def test_should_raise_403_forbidden(self): response = self.client.get( - "/auth/fab/v1/roles/Admin", environ_overrides={"REMOTE_USER": "test_no_permissions"} + "/fab/v1/roles/Admin", environ_overrides={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 @@ -112,13 +110,13 @@ def test_should_raise_403_forbidden(self): indirect=["set_auth_role_public"], ) def test_with_auth_role_public_set(self, set_auth_role_public, expected_status_code): - response = self.client.get("/auth/fab/v1/roles/Admin") + response = self.client.get("/fab/v1/roles/Admin") assert response.status_code == expected_status_code, response.json class TestGetRolesEndpoint(TestRoleEndpoint): def test_should_response_200(self): - response = self.client.get("/auth/fab/v1/roles", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/fab/v1/roles", environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 200 existing_roles = set(EXISTING_ROLES) existing_roles.update(["Test", "TestNoPermissions"]) @@ -127,21 +125,19 @@ def test_should_response_200(self): assert roles == existing_roles def test_should_raises_401_unauthenticated(self): - response = self.client.get("/auth/fab/v1/roles") + response = self.client.get("/fab/v1/roles") assert_401(response) def test_should_raises_400_for_invalid_order_by(self): response = self.client.get( - "/auth/fab/v1/roles?order_by=invalid", environ_overrides={"REMOTE_USER": "test"} + "/fab/v1/roles?order_by=invalid", environ_overrides={"REMOTE_USER": "test"} ) assert response.status_code == 400 msg = "Ordering with 'invalid' is disallowed or the attribute does not exist on the model" assert response.json["detail"] == msg def test_should_raise_403_forbidden(self): - response = self.client.get( - "/auth/fab/v1/roles", environ_overrides={"REMOTE_USER": "test_no_permissions"} - ) + response = self.client.get("/fab/v1/roles", environ_overrides={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 @pytest.mark.parametrize( @@ -150,7 +146,7 @@ def test_should_raise_403_forbidden(self): indirect=["set_auth_role_public"], ) def test_with_auth_role_public_set(self, set_auth_role_public, expected_status_code): - response = self.client.get("/auth/fab/v1/roles") + response = self.client.get("/fab/v1/roles") assert response.status_code == expected_status_code, response.json @@ -158,20 +154,20 @@ class TestGetRolesEndpointPaginationandFilter(TestRoleEndpoint): @pytest.mark.parametrize( "url, expected_roles", [ - ("/auth/fab/v1/roles?limit=1", ["Admin"]), - ("/auth/fab/v1/roles?limit=2", ["Admin", "Op"]), + ("/fab/v1/roles?limit=1", ["Admin"]), + ("/fab/v1/roles?limit=2", ["Admin", "Op"]), ( - "/auth/fab/v1/roles?offset=1", + "/fab/v1/roles?offset=1", ["Op", "Public", "Test", "TestNoPermissions", "User", "Viewer"], ), ( - "/auth/fab/v1/roles?offset=0", + "/fab/v1/roles?offset=0", ["Admin", "Op", "Public", "Test", "TestNoPermissions", "User", "Viewer"], ), - ("/auth/fab/v1/roles?limit=1&offset=2", ["Public"]), - ("/auth/fab/v1/roles?limit=1&offset=1", ["Op"]), + ("/fab/v1/roles?limit=1&offset=2", ["Public"]), + ("/fab/v1/roles?limit=1&offset=1", ["Op"]), ( - "/auth/fab/v1/roles?limit=2&offset=2", + "/fab/v1/roles?limit=2&offset=2", ["Public", "Test"], ), ], @@ -189,7 +185,7 @@ def test_can_handle_limit_and_offset(self, url, expected_roles): class TestGetPermissionsEndpoint(TestRoleEndpoint): def test_should_response_200(self): - response = self.client.get("/auth/fab/v1/permissions", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/fab/v1/permissions", environ_overrides={"REMOTE_USER": "test"}) actions = {i[0] for i in self.app.appbuilder.sm.get_all_permissions() if i} assert response.status_code == 200 assert response.json["total_entries"] == len(actions) @@ -197,12 +193,12 @@ def test_should_response_200(self): assert actions == returned_actions def test_should_raises_401_unauthenticated(self): - response = self.client.get("/auth/fab/v1/permissions") + response = self.client.get("/fab/v1/permissions") assert_401(response) def test_should_raise_403_forbidden(self): response = self.client.get( - "/auth/fab/v1/permissions", environ_overrides={"REMOTE_USER": "test_no_permissions"} + "/fab/v1/permissions", environ_overrides={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 @@ -212,7 +208,7 @@ def test_should_raise_403_forbidden(self): indirect=["set_auth_role_public"], ) def test_with_auth_role_public_set(self, set_auth_role_public, expected_status_code): - response = self.client.get("/auth/fab/v1/permissions") + response = self.client.get("/fab/v1/permissions") assert response.status_code == expected_status_code, response.json @@ -222,9 +218,7 @@ def test_post_should_respond_200(self): "name": "Test2", "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], } - response = self.client.post( - "/auth/fab/v1/roles", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/fab/v1/roles", json=payload, environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 200 role = self.app.appbuilder.sm.find_role("Test2") assert role is not None @@ -295,9 +289,7 @@ def test_post_should_respond_200(self): ], ) def test_post_should_respond_400_for_invalid_payload(self, payload, error_message): - response = self.client.post( - "/auth/fab/v1/roles", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/fab/v1/roles", json=payload, environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 400 assert response.json == { "detail": error_message, @@ -311,9 +303,7 @@ def test_post_should_respond_409_already_exist(self): "name": "Test", "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], } - response = self.client.post( - "/auth/fab/v1/roles", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/fab/v1/roles", json=payload, environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 409 assert response.json == { "detail": "Role with name 'Test' already exists; please update with the PATCH endpoint", @@ -324,7 +314,7 @@ def test_post_should_respond_409_already_exist(self): def test_should_raises_401_unauthenticated(self): response = self.client.post( - "/auth/fab/v1/roles", + "/fab/v1/roles", json={ "name": "Test2", "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], @@ -335,7 +325,7 @@ def test_should_raises_401_unauthenticated(self): def test_should_raise_403_forbidden(self): response = self.client.post( - "/auth/fab/v1/roles", + "/fab/v1/roles", json={ "name": "mytest2", "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], @@ -354,23 +344,21 @@ def test_with_auth_role_public_set(self, set_auth_role_public, expected_status_c "name": "Test2", "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], } - response = self.client.post("/auth/fab/v1/roles", json=payload) + response = self.client.post("/fab/v1/roles", json=payload) assert response.status_code == expected_status_code, response.json class TestDeleteRole(TestRoleEndpoint): def test_delete_should_respond_204(self, session): role = create_role(self.app, "mytestrole") - response = self.client.delete( - f"/auth/fab/v1/roles/{role.name}", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.delete(f"/fab/v1/roles/{role.name}", environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 204 role_obj = session.query(Role).filter(Role.name == role.name).all() assert len(role_obj) == 0 def test_delete_should_respond_404(self): response = self.client.delete( - "/auth/fab/v1/roles/invalidrolename", environ_overrides={"REMOTE_USER": "test"} + "/fab/v1/roles/invalidrolename", environ_overrides={"REMOTE_USER": "test"} ) assert response.status_code == 404 assert response.json == { @@ -381,13 +369,13 @@ def test_delete_should_respond_404(self): } def test_should_raises_401_unauthenticated(self): - response = self.client.delete("/auth/fab/v1/roles/test") + response = self.client.delete("/fab/v1/roles/test") assert_401(response) def test_should_raise_403_forbidden(self): response = self.client.delete( - "/auth/fab/v1/roles/test", environ_overrides={"REMOTE_USER": "test_no_permissions"} + "/fab/v1/roles/test", environ_overrides={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 @@ -398,7 +386,7 @@ def test_should_raise_403_forbidden(self): ) def test_with_auth_role_public_set(self, set_auth_role_public, expected_status_code): role = create_role(self.app, "mytestrole") - response = self.client.delete(f"/auth/fab/v1/roles/{role.name}") + response = self.client.delete(f"/fab/v1/roles/{role.name}") assert response.status_code == expected_status_code, response.location @@ -420,7 +408,7 @@ class TestPatchRole(TestRoleEndpoint): def test_patch_should_respond_200(self, payload, expected_name, expected_actions): role = create_role(self.app, "mytestrole") response = self.client.patch( - f"/auth/fab/v1/roles/{role.name}", json=payload, environ_overrides={"REMOTE_USER": "test"} + f"/fab/v1/roles/{role.name}", json=payload, environ_overrides={"REMOTE_USER": "test"} ) assert response.status_code == 200 assert response.json["name"] == expected_name @@ -431,7 +419,7 @@ def test_patch_should_update_correct_roles_permissions(self): create_role(self.app, "already_exists") response = self.client.patch( - "/auth/fab/v1/roles/role_to_change", + "/fab/v1/roles/role_to_change", json={ "name": "already_exists", "actions": [{"action": {"name": "can_delete"}, "resource": {"name": "XComs"}}], @@ -476,7 +464,7 @@ def test_patch_should_respond_200_with_update_mask( role = create_role(self.app, "mytestrole") assert role.permissions == [] response = self.client.patch( - f"/auth/fab/v1/roles/{role.name}{update_mask}", + f"/fab/v1/roles/{role.name}{update_mask}", json=payload, environ_overrides={"REMOTE_USER": "test"}, ) @@ -488,7 +476,7 @@ def test_patch_should_respond_400_for_invalid_fields_in_update_mask(self): role = create_role(self.app, "mytestrole") payload = {"name": "testme"} response = self.client.patch( - f"/auth/fab/v1/roles/{role.name}?update_mask=invalid_name", + f"/fab/v1/roles/{role.name}?update_mask=invalid_name", json=payload, environ_overrides={"REMOTE_USER": "test"}, ) @@ -548,7 +536,7 @@ def test_patch_should_respond_400_for_invalid_fields_in_update_mask(self): def test_patch_should_respond_400_for_invalid_update(self, payload, expected_error): role = create_role(self.app, "mytestrole") response = self.client.patch( - f"/auth/fab/v1/roles/{role.name}", + f"/fab/v1/roles/{role.name}", json=payload, environ_overrides={"REMOTE_USER": "test"}, ) @@ -557,7 +545,7 @@ def test_patch_should_respond_400_for_invalid_update(self, payload, expected_err def test_should_raises_401_unauthenticated(self): response = self.client.patch( - "/auth/fab/v1/roles/test", + "/fab/v1/roles/test", json={ "name": "mytest2", "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], @@ -568,7 +556,7 @@ def test_should_raises_401_unauthenticated(self): def test_should_raise_403_forbidden(self): response = self.client.patch( - "/auth/fab/v1/roles/test", + "/fab/v1/roles/test", json={ "name": "mytest2", "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], @@ -585,7 +573,7 @@ def test_should_raise_403_forbidden(self): def test_with_auth_role_public_set(self, set_auth_role_public, expected_status_code): role = create_role(self.app, "mytestrole") response = self.client.patch( - f"/auth/fab/v1/roles/{role.name}", + f"/fab/v1/roles/{role.name}", json={"name": "mytest"}, ) assert response.status_code == expected_status_code, response.json diff --git a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_user_endpoint.py b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_user_endpoint.py index faba6ea87b76f..5132003229634 100644 --- a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_user_endpoint.py +++ b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_user_endpoint.py @@ -106,7 +106,7 @@ def test_should_respond_200(self): users = self._create_users(1) self.session.add_all(users) self.session.commit() - response = self.client.get("/auth/fab/v1/users/TEST_USER1", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/fab/v1/users/TEST_USER1", environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 200 assert response.json == { "active": True, @@ -134,7 +134,7 @@ def test_last_names_can_be_empty(self): ) self.session.add_all([prince]) self.session.commit() - response = self.client.get("/auth/fab/v1/users/prince", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/fab/v1/users/prince", environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 200 assert response.json == { "active": True, @@ -162,7 +162,7 @@ def test_first_names_can_be_empty(self): ) self.session.add_all([liberace]) self.session.commit() - response = self.client.get("/auth/fab/v1/users/liberace", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/fab/v1/users/liberace", environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 200 assert response.json == { "active": True, @@ -190,7 +190,7 @@ def test_both_first_and_last_names_can_be_empty(self): ) self.session.add_all([nameless]) self.session.commit() - response = self.client.get("/auth/fab/v1/users/nameless", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/fab/v1/users/nameless", environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 200 assert response.json == { "active": True, @@ -207,9 +207,7 @@ def test_both_first_and_last_names_can_be_empty(self): } def test_should_respond_404(self): - response = self.client.get( - "/auth/fab/v1/users/invalid-user", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/fab/v1/users/invalid-user", environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 404 assert response.json == { "detail": "The User with username `invalid-user` was not found", @@ -219,32 +217,30 @@ def test_should_respond_404(self): } def test_should_raises_401_unauthenticated(self): - response = self.client.get("/auth/fab/v1/users/TEST_USER1") + response = self.client.get("/fab/v1/users/TEST_USER1") assert_401(response) def test_should_raise_403_forbidden(self): response = self.client.get( - "/auth/fab/v1/users/TEST_USER1", environ_overrides={"REMOTE_USER": "test_no_permissions"} + "/fab/v1/users/TEST_USER1", environ_overrides={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 class TestGetUsers(TestUserEndpoint): def test_should_response_200(self): - response = self.client.get("/auth/fab/v1/users", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/fab/v1/users", environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 200 assert response.json["total_entries"] == 2 usernames = [user["username"] for user in response.json["users"] if user] assert usernames == ["test", "test_no_permissions"] def test_should_raises_401_unauthenticated(self): - response = self.client.get("/auth/fab/v1/users") + response = self.client.get("/fab/v1/users") assert_401(response) def test_should_raise_403_forbidden(self): - response = self.client.get( - "/auth/fab/v1/users", environ_overrides={"REMOTE_USER": "test_no_permissions"} - ) + response = self.client.get("/fab/v1/users", environ_overrides={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 @@ -252,10 +248,10 @@ class TestGetUsersPagination(TestUserEndpoint): @pytest.mark.parametrize( "url, expected_usernames", [ - ("/auth/fab/v1/users?limit=1", ["test"]), - ("/auth/fab/v1/users?limit=2", ["test", "test_no_permissions"]), + ("/fab/v1/users?limit=1", ["test"]), + ("/fab/v1/users?limit=2", ["test", "test_no_permissions"]), ( - "/auth/fab/v1/users?offset=5", + "/fab/v1/users?offset=5", [ "TEST_USER4", "TEST_USER5", @@ -267,7 +263,7 @@ class TestGetUsersPagination(TestUserEndpoint): ], ), ( - "/auth/fab/v1/users?offset=0", + "/fab/v1/users?offset=0", [ "test", "test_no_permissions", @@ -283,10 +279,10 @@ class TestGetUsersPagination(TestUserEndpoint): "TEST_USER10", ], ), - ("/auth/fab/v1/users?limit=1&offset=5", ["TEST_USER4"]), - ("/auth/fab/v1/users?limit=1&offset=1", ["test_no_permissions"]), + ("/fab/v1/users?limit=1&offset=5", ["TEST_USER4"]), + ("/fab/v1/users?limit=1&offset=1", ["test_no_permissions"]), ( - "/auth/fab/v1/users?limit=2&offset=2", + "/fab/v1/users?limit=2&offset=2", ["TEST_USER1", "TEST_USER2"], ), ], @@ -306,7 +302,7 @@ def test_should_respect_page_size_limit_default(self): self.session.add_all(users) self.session.commit() - response = self.client.get("/auth/fab/v1/users", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/fab/v1/users", environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 200 # Explicitly add the 2 users on setUp assert response.json["total_entries"] == 200 + len(["test", "test_no_permissions"]) @@ -316,9 +312,7 @@ def test_should_response_400_with_invalid_order_by(self): users = self._create_users(2) self.session.add_all(users) self.session.commit() - response = self.client.get( - "/auth/fab/v1/users?order_by=myname", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/fab/v1/users?order_by=myname", environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 400 msg = "Ordering with 'myname' is disallowed or the attribute does not exist on the model" assert response.json["detail"] == msg @@ -328,7 +322,7 @@ def test_limit_of_zero_should_return_default(self): self.session.add_all(users) self.session.commit() - response = self.client.get("/auth/fab/v1/users?limit=0", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/fab/v1/users?limit=0", environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 200 # Explicit add the 2 users on setUp assert response.json["total_entries"] == 200 + len(["test", "test_no_permissions"]) @@ -340,7 +334,7 @@ def test_should_return_conf_max_if_req_max_above_conf(self): self.session.add_all(users) self.session.commit() - response = self.client.get("/auth/fab/v1/users?limit=180", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/fab/v1/users?limit=180", environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 200 assert len(response.json["users"]) == 150 @@ -433,7 +427,7 @@ class TestPostUser(TestUserEndpoint): def test_with_default_role(self, autoclean_username, autoclean_user_payload): self.client.application.config["AUTH_USER_REGISTRATION_ROLE"] = "Public" response = self.client.post( - "/auth/fab/v1/users", + "/fab/v1/users", json=autoclean_user_payload, environ_overrides={"REMOTE_USER": "test"}, ) @@ -446,7 +440,7 @@ def test_with_default_role(self, autoclean_username, autoclean_user_payload): def test_with_custom_roles(self, autoclean_username, autoclean_user_payload): response = self.client.post( - "/auth/fab/v1/users", + "/fab/v1/users", json={"roles": [{"name": "User"}, {"name": "Viewer"}], **autoclean_user_payload}, environ_overrides={"REMOTE_USER": "test"}, ) @@ -460,7 +454,7 @@ def test_with_custom_roles(self, autoclean_username, autoclean_user_payload): @pytest.mark.usefixtures("user_different") def test_with_existing_different_user(self, autoclean_user_payload): response = self.client.post( - "/auth/fab/v1/users", + "/fab/v1/users", json={"roles": [{"name": "User"}, {"name": "Viewer"}], **autoclean_user_payload}, environ_overrides={"REMOTE_USER": "test"}, ) @@ -468,14 +462,14 @@ def test_with_existing_different_user(self, autoclean_user_payload): def test_unauthenticated(self, autoclean_user_payload): response = self.client.post( - "/auth/fab/v1/users", + "/fab/v1/users", json=autoclean_user_payload, ) assert response.status_code == 401, response.json def test_forbidden(self, autoclean_user_payload): response = self.client.post( - "/auth/fab/v1/users", + "/fab/v1/users", json=autoclean_user_payload, environ_overrides={"REMOTE_USER": "test_no_permissions"}, ) @@ -499,7 +493,7 @@ def test_already_exists( existing = request.getfixturevalue(existing_user_fixture_name) response = self.client.post( - "/auth/fab/v1/users", + "/fab/v1/users", json=autoclean_user_payload, environ_overrides={"REMOTE_USER": "test"}, ) @@ -535,7 +529,7 @@ def test_already_exists( ) def test_invalid_payload(self, autoclean_user_payload, payload_converter, error_message): response = self.client.post( - "/auth/fab/v1/users", + "/fab/v1/users", json=payload_converter(autoclean_user_payload), environ_overrides={"REMOTE_USER": "test"}, ) @@ -550,7 +544,7 @@ def test_invalid_payload(self, autoclean_user_payload, payload_converter, error_ def test_internal_server_error(self, autoclean_user_payload): with unittest.mock.patch.object(self.app.appbuilder.sm, "add_user", return_value=None): response = self.client.post( - "/auth/fab/v1/users", + "/fab/v1/users", json=autoclean_user_payload, environ_overrides={"REMOTE_USER": "test"}, ) @@ -567,7 +561,7 @@ class TestPatchUser(TestUserEndpoint): def test_change(self, autoclean_username, autoclean_user_payload): autoclean_user_payload["first_name"] = "Changed" response = self.client.patch( - f"/auth/fab/v1/users/{autoclean_username}", + f"/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, environ_overrides={"REMOTE_USER": "test"}, ) @@ -583,7 +577,7 @@ def test_change_with_update_mask(self, autoclean_username, autoclean_user_payloa autoclean_user_payload["first_name"] = "Changed" autoclean_user_payload["last_name"] = "McTesterson" response = self.client.patch( - f"/auth/fab/v1/users/{autoclean_username}?update_mask=last_name", + f"/fab/v1/users/{autoclean_username}?update_mask=last_name", json=autoclean_user_payload, environ_overrides={"REMOTE_USER": "test"}, ) @@ -613,7 +607,7 @@ def test_patch_already_exists( ): autoclean_user_payload.update(payload) response = self.client.patch( - f"/auth/fab/v1/users/{autoclean_username}", + f"/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, environ_overrides={"REMOTE_USER": "test"}, ) @@ -634,7 +628,7 @@ def test_required_fields( ): autoclean_user_payload.pop(field) response = self.client.patch( - f"/auth/fab/v1/users/{autoclean_username}", + f"/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, environ_overrides={"REMOTE_USER": "test"}, ) @@ -646,7 +640,7 @@ def test_username_can_be_updated(self, autoclean_user_payload, autoclean_usernam testusername = "testusername" autoclean_user_payload.update({"username": testusername}) response = self.client.patch( - f"/auth/fab/v1/users/{autoclean_username}", + f"/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, environ_overrides={"REMOTE_USER": "test"}, ) @@ -666,7 +660,7 @@ def test_password_hashed( ): autoclean_user_payload["password"] = "new-pass" response = self.client.patch( - f"/auth/fab/v1/users/{autoclean_username}", + f"/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, environ_overrides={"REMOTE_USER": "test"}, ) @@ -685,7 +679,7 @@ def test_replace_roles(self, autoclean_username, autoclean_user_payload): # Patching a user's roles should replace the entire list. autoclean_user_payload["roles"] = [{"name": "User"}, {"name": "Viewer"}] response = self.client.patch( - f"/auth/fab/v1/users/{autoclean_username}?update_mask=roles", + f"/fab/v1/users/{autoclean_username}?update_mask=roles", json=autoclean_user_payload, environ_overrides={"REMOTE_USER": "test"}, ) @@ -696,7 +690,7 @@ def test_replace_roles(self, autoclean_username, autoclean_user_payload): def test_unchanged(self, autoclean_username, autoclean_user_payload): # Should allow a PATCH that changes nothing. response = self.client.patch( - f"/auth/fab/v1/users/{autoclean_username}", + f"/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, environ_overrides={"REMOTE_USER": "test"}, ) @@ -708,7 +702,7 @@ def test_unchanged(self, autoclean_username, autoclean_user_payload): @pytest.mark.usefixtures("autoclean_admin_user") def test_unauthenticated(self, autoclean_username, autoclean_user_payload): response = self.client.patch( - f"/auth/fab/v1/users/{autoclean_username}", + f"/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, ) assert response.status_code == 401, response.json @@ -716,7 +710,7 @@ def test_unauthenticated(self, autoclean_username, autoclean_user_payload): @pytest.mark.usefixtures("autoclean_admin_user") def test_forbidden(self, autoclean_username, autoclean_user_payload): response = self.client.patch( - f"/auth/fab/v1/users/{autoclean_username}", + f"/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, environ_overrides={"REMOTE_USER": "test_no_permissions"}, ) @@ -725,7 +719,7 @@ def test_forbidden(self, autoclean_username, autoclean_user_payload): def test_not_found(self, autoclean_username, autoclean_user_payload): # This test does not populate autoclean_admin_user into the database. response = self.client.patch( - f"/auth/fab/v1/users/{autoclean_username}", + f"/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, environ_overrides={"REMOTE_USER": "test"}, ) @@ -765,7 +759,7 @@ def test_invalid_payload( error_message, ): response = self.client.patch( - f"/auth/fab/v1/users/{autoclean_username}", + f"/fab/v1/users/{autoclean_username}", json=payload_converter(autoclean_user_payload), environ_overrides={"REMOTE_USER": "test"}, ) @@ -782,7 +776,7 @@ class TestDeleteUser(TestUserEndpoint): @pytest.mark.usefixtures("autoclean_admin_user") def test_delete(self, autoclean_username): response = self.client.delete( - f"/auth/fab/v1/users/{autoclean_username}", + f"/fab/v1/users/{autoclean_username}", environ_overrides={"REMOTE_USER": "test"}, ) assert response.status_code == 204, response.json # NO CONTENT. @@ -791,7 +785,7 @@ def test_delete(self, autoclean_username): @pytest.mark.usefixtures("autoclean_admin_user") def test_unauthenticated(self, autoclean_username): response = self.client.delete( - f"/auth/fab/v1/users/{autoclean_username}", + f"/fab/v1/users/{autoclean_username}", ) assert response.status_code == 401, response.json assert self.session.query(count(User.id)).filter(User.username == autoclean_username).scalar() == 1 @@ -799,7 +793,7 @@ def test_unauthenticated(self, autoclean_username): @pytest.mark.usefixtures("autoclean_admin_user") def test_forbidden(self, autoclean_username): response = self.client.delete( - f"/auth/fab/v1/users/{autoclean_username}", + f"/fab/v1/users/{autoclean_username}", environ_overrides={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403, response.json @@ -808,7 +802,7 @@ def test_forbidden(self, autoclean_username): def test_not_found(self, autoclean_username): # This test does not populate autoclean_admin_user into the database. response = self.client.delete( - f"/auth/fab/v1/users/{autoclean_username}", + f"/fab/v1/users/{autoclean_username}", environ_overrides={"REMOTE_USER": "test"}, ) assert response.status_code == 404, response.json diff --git a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py index 526fc7441d3b6..399937e63a515 100644 --- a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py +++ b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py @@ -20,11 +20,10 @@ from itertools import chain from typing import TYPE_CHECKING from unittest import mock -from unittest.mock import Mock, patch +from unittest.mock import Mock import pytest from flask import Flask, g -from flask_appbuilder.menu import Menu from airflow.exceptions import AirflowConfigException, AirflowException from airflow.providers.fab.www.extensions.init_appbuilder import init_appbuilder @@ -39,7 +38,6 @@ with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.fab_auth_manager import FabAuthManager - from airflow.providers.fab.auth_manager.models import User from airflow.providers.fab.auth_manager.security_manager.override import FabAirflowSecurityManagerOverride from airflow.providers.common.compat.security.permissions import RESOURCE_ASSET @@ -102,29 +100,6 @@ def auth_manager_with_appbuilder(flask_app): @pytest.mark.db_test class TestFabAuthManager: - @pytest.mark.parametrize( - "id,first_name,last_name,username,email,expected", - [ - (1, "First", "Last", None, None, "First Last"), - (1, "First", None, "user", None, "First"), - (1, None, "Last", "user", "email", "Last"), - (1, None, None, None, "email", ""), - ], - ) - @mock.patch.object(FabAuthManager, "get_user") - def test_get_user_display_name( - self, mock_get_user, id, first_name, last_name, username, email, expected, auth_manager - ): - user = User() - user.id = id - user.first_name = first_name - user.last_name = last_name - user.username = username - user.email = email - mock_get_user.return_value = user - - assert auth_manager.get_user_display_name() == expected - @mock.patch("flask_login.utils._get_user") def test_get_user(self, mock_current_user, minimal_app_for_auth_api, auth_manager): user = Mock() @@ -474,79 +449,6 @@ def test_is_authorized_custom_view( result = auth_manager.is_authorized_custom_view(method=method, resource_name=resource_name, user=user) assert result == expected_result - @patch.object(FabAuthManager, "security_manager") - def test_filter_permitted_menu_items(self, mock_security_manager, auth_manager): - mock_security_manager.has_access.side_effect = [True, False, True, True, False] - - menu = Menu() - menu.add_link( - # These may not all be valid types, but it does let us check each attr is copied - name="item1", - href="h1", - icon="i1", - label="l1", - baseview="b1", - cond="c1", - ) - menu.add_link("item2") - menu.add_link("item3") - menu.add_link("item3.1", category="item3") - menu.add_link("item3.2", category="item3") - - result = auth_manager.filter_permitted_menu_items(menu.get_list()) - - assert len(result) == 2 - assert result[0].name == "item1" - assert result[1].name == "item3" - assert len(result[1].childs) == 1 - assert result[1].childs[0].name == "item3.1" - # check we've copied every attr - assert result[0].href == "h1" - assert result[0].icon == "i1" - assert result[0].label == "l1" - assert result[0].baseview == "b1" - assert result[0].cond == "c1" - - @patch.object(FabAuthManager, "security_manager") - def test_filter_permitted_menu_items_twice(self, mock_security_manager, auth_manager): - mock_security_manager.has_access.side_effect = [ - # 1st call - True, # menu 1 - False, # menu 2 - True, # menu 3 - True, # Item 3.1 - False, # Item 3.2 - # 2nd call - False, # menu 1 - True, # menu 2 - True, # menu 3 - False, # Item 3.1 - True, # Item 3.2 - ] - - menu = Menu() - menu.add_link("item1") - menu.add_link("item2") - menu.add_link("item3") - menu.add_link("item3.1", category="item3") - menu.add_link("item3.2", category="item3") - - result = auth_manager.filter_permitted_menu_items(menu.get_list()) - - assert len(result) == 2 - assert result[0].name == "item1" - assert result[1].name == "item3" - assert len(result[1].childs) == 1 - assert result[1].childs[0].name == "item3.1" - - result = auth_manager.filter_permitted_menu_items(menu.get_list()) - - assert len(result) == 2 - assert result[0].name == "item2" - assert result[1].name == "item3" - assert len(result[1].childs) == 1 - assert result[1].childs[0].name == "item3.2" - @pytest.mark.db_test def test_security_manager_return_fab_security_manager_override(self, auth_manager_with_appbuilder): assert isinstance(auth_manager_with_appbuilder.security_manager, FabAirflowSecurityManagerOverride) @@ -590,18 +492,3 @@ def test_get_url_logout(self, mock_url_for, auth_manager_with_appbuilder): auth_manager_with_appbuilder.security_manager.auth_view.endpoint = "test_endpoint" auth_manager_with_appbuilder.get_url_logout() mock_url_for.assert_called_once_with("test_endpoint.logout") - - @pytest.mark.db_test - def test_get_url_user_profile_when_auth_view_not_defined(self, auth_manager_with_appbuilder): - assert auth_manager_with_appbuilder.get_url_user_profile() is None - - @pytest.mark.db_test - @mock.patch("airflow.providers.fab.auth_manager.fab_auth_manager.url_for") - def test_get_url_user_profile(self, mock_url_for, auth_manager_with_appbuilder): - expected_url = "test_url" - mock_url_for.return_value = expected_url - auth_manager_with_appbuilder.security_manager.user_view = Mock() - auth_manager_with_appbuilder.security_manager.user_view.endpoint = "test_endpoint" - actual_url = auth_manager_with_appbuilder.get_url_user_profile() - mock_url_for.assert_called_once_with("test_endpoint.userinfo") - assert actual_url == expected_url diff --git a/providers/fab/tests/unit/fab/www/test_auth.py b/providers/fab/tests/unit/fab/www/test_auth.py index de69129fa2c5e..cfd31e282be89 100644 --- a/providers/fab/tests/unit/fab/www/test_auth.py +++ b/providers/fab/tests/unit/fab/www/test_auth.py @@ -51,8 +51,9 @@ def method_test(self): return True @patch("airflow.providers.fab.www.auth.get_auth_manager") + @patch("airflow.providers.fab.www.auth.get_fab_auth_manager") def test_has_access_no_details_when_authorized( - self, mock_get_auth_manager, decorator_name, is_authorized_method_name + self, _, mock_get_auth_manager, decorator_name, is_authorized_method_name ): auth_manager = Mock() is_authorized_method = Mock() @@ -66,9 +67,10 @@ def test_has_access_no_details_when_authorized( assert result is True @patch("airflow.providers.fab.www.auth.get_auth_manager") + @patch("airflow.providers.fab.www.auth.get_fab_auth_manager") @patch("airflow.providers.fab.www.auth.render_template") def test_has_access_no_details_when_no_permission( - self, mock_render_template, mock_get_auth_manager, decorator_name, is_authorized_method_name + self, mock_render_template, _, mock_get_auth_manager, decorator_name, is_authorized_method_name ): auth_manager = Mock() is_authorized_method = Mock() @@ -204,7 +206,10 @@ def test_has_access_dag_entities_when_authorized(self, mock_get_auth_manager, da @pytest.mark.db_test @patch("airflow.providers.fab.www.auth.get_auth_manager") - def test_has_access_dag_entities_when_unauthorized(self, mock_get_auth_manager, app, dag_access_entity): + @patch("airflow.providers.fab.www.auth.get_fab_auth_manager") + def test_has_access_dag_entities_when_unauthorized( + self, _, mock_get_auth_manager, app, dag_access_entity + ): auth_manager = Mock() auth_manager.batch_is_authorized_dag.return_value = False mock_get_auth_manager.return_value = auth_manager diff --git a/providers/google/tests/unit/google/common/auth_backend/test_google_openid.py b/providers/google/tests/unit/google/common/auth_backend/test_google_openid.py index 938e63bb9a072..5ec863080f20a 100644 --- a/providers/google/tests/unit/google/common/auth_backend/test_google_openid.py +++ b/providers/google/tests/unit/google/common/auth_backend/test_google_openid.py @@ -88,7 +88,7 @@ def test_success(self, mock_verify_token): } with self.app.test_client() as test_client: - response = test_client.get("/auth/fab/v1/users", headers={"Authorization": "bearer JWT_TOKEN"}) + response = test_client.get("/fab/v1/users", headers={"Authorization": "bearer JWT_TOKEN"}) assert response.status_code == 200 @@ -102,7 +102,7 @@ def test_malformed_headers(self, mock_verify_token, auth_header): } with self.app.test_client() as test_client: - response = test_client.get("/auth/fab/v1/users", headers={"Authorization": auth_header}) + response = test_client.get("/fab/v1/users", headers={"Authorization": auth_header}) assert response.status_code == 401 @@ -115,7 +115,7 @@ def test_invalid_iss_in_jwt_token(self, mock_verify_token): } with self.app.test_client() as test_client: - response = test_client.get("/auth/fab/v1/users", headers={"Authorization": "bearer JWT_TOKEN"}) + response = test_client.get("/fab/v1/users", headers={"Authorization": "bearer JWT_TOKEN"}) assert response.status_code == 401 @@ -128,14 +128,14 @@ def test_user_not_exists(self, mock_verify_token): } with self.app.test_client() as test_client: - response = test_client.get("/auth/fab/v1/users", headers={"Authorization": "bearer JWT_TOKEN"}) + response = test_client.get("/fab/v1/users", headers={"Authorization": "bearer JWT_TOKEN"}) assert response.status_code == 401 @conf_vars({("api", "auth_backends"): "airflow.providers.google.common.auth_backend.google_openid"}) def test_missing_id_token(self): with self.app.test_client() as test_client: - response = test_client.get("/auth/fab/v1/users") + response = test_client.get("/fab/v1/users") assert response.status_code == 401 @@ -145,6 +145,6 @@ def test_invalid_id_token(self, mock_verify_token): mock_verify_token.side_effect = GoogleAuthError("Invalid token") with self.app.test_client() as test_client: - response = test_client.get("/auth/fab/v1/users", headers={"Authorization": "bearer JWT_TOKEN"}) + response = test_client.get("/fab/v1/users", headers={"Authorization": "bearer JWT_TOKEN"}) assert response.status_code == 401 diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py index aff8150490bad..c228ad9158443 100644 --- a/tests/auth/managers/test_base_auth_manager.py +++ b/tests/auth/managers/test_base_auth_manager.py @@ -29,11 +29,8 @@ PoolDetails, VariableDetails, ) -from airflow.exceptions import AirflowException if TYPE_CHECKING: - from flask_appbuilder.menu import MenuItem - from airflow.auth.managers.base_auth_manager import ResourceMethod from airflow.auth.managers.models.resource_details import ( AccessView, @@ -55,9 +52,6 @@ def get_name(self) -> str: class EmptyAuthManager(BaseAuthManager[BaseAuthManagerUserTest]): - def get_user(self) -> BaseAuthManagerUserTest: - raise NotImplementedError() - def deserialize_user(self, token: dict[str, Any]) -> BaseAuthManagerUserTest: raise NotImplementedError() @@ -129,18 +123,9 @@ def is_authorized_custom_view( ): raise NotImplementedError() - def is_logged_in(self) -> bool: - raise NotImplementedError() - def get_url_login(self, **kwargs) -> str: raise NotImplementedError() - def get_url_logout(self) -> str: - raise NotImplementedError() - - def filter_permitted_menu_items(self, menu_items: list[MenuItem]) -> list[MenuItem]: - raise NotImplementedError() - @pytest.fixture def auth_manager(): @@ -151,42 +136,9 @@ class TestBaseAuthManager: def test_get_cli_commands_return_empty_list(self, auth_manager): assert auth_manager.get_cli_commands() == [] - def test_get_api_endpoints_return_none(self, auth_manager): - assert auth_manager.get_api_endpoints() is None - def test_get_fastapi_app_return_none(self, auth_manager): assert auth_manager.get_fastapi_app() is None - def test_get_user_name(self, auth_manager): - user = Mock() - user.get_name.return_value = "test_username" - auth_manager.get_user = MagicMock(return_value=user) - result = auth_manager.get_user_name() - assert result == "test_username" - - def test_get_user_name_when_not_logged_in(self, auth_manager): - auth_manager.get_user = MagicMock(return_value=None) - with pytest.raises(AirflowException): - auth_manager.get_user_name() - - def test_get_user_display_name_return_user_name(self, auth_manager): - auth_manager.get_user_name = MagicMock(return_value="test_user") - assert auth_manager.get_user_display_name() == "test_user" - - def test_get_user_id_return_user_id(self, auth_manager): - user = Mock() - user.get_id = MagicMock(return_value="test_user") - auth_manager.get_user = MagicMock(return_value=user) - assert auth_manager.get_user_id() == "test_user" - - def test_get_user_id_raise_exception_when_no_user(self, auth_manager): - auth_manager.get_user = MagicMock(return_value=None) - with pytest.raises(AirflowException, match="The user must be signed in."): - auth_manager.get_user_id() - - def test_get_url_user_profile_return_none(self, auth_manager): - assert auth_manager.get_url_user_profile() is None - @patch("airflow.auth.managers.base_auth_manager.JWTSigner") @patch.object(EmptyAuthManager, "deserialize_user") def test_get_user_from_token(self, mock_deserialize_user, mock_jwt_signer, auth_manager):