From 02dc9da5e740ac52eb18fe39a0fb6428fd417884 Mon Sep 17 00:00:00 2001 From: vincbeck Date: Thu, 6 Mar 2025 11:47:55 -0500 Subject: [PATCH] Fix and simplify `get_permitted_dag_ids` in auth manager --- airflow/auth/managers/base_auth_manager.py | 32 +++----- .../aws/auth_manager/aws_auth_manager.py | 34 +++----- .../aws/auth_manager/test_aws_auth_manager.py | 56 ++++++++++--- .../fab/auth_manager/fab_auth_manager.py | 40 ++++------ .../auth_manager/security_manager/override.py | 4 +- .../fab/auth_manager/test_fab_auth_manager.py | 80 ++++++++++++++++++- .../unit/fab/auth_manager/test_security.py | 2 +- tests/auth/managers/test_base_auth_manager.py | 17 +--- 8 files changed, 167 insertions(+), 98 deletions(-) diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index 0ca18db8121ef..f3b86600a0583 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -34,7 +34,7 @@ from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: - from collections.abc import Container, Sequence + from collections.abc import Sequence from fastapi import FastAPI from sqlalchemy.orm import Session @@ -331,7 +331,7 @@ def get_permitted_dag_ids( self, *, user: T, - methods: Container[ResourceMethod] | None = None, + method: ResourceMethod = "GET", session: Session = NEW_SESSION, ) -> set[str]: """ @@ -342,45 +342,31 @@ def get_permitted_dag_ids( implementation to provide a more efficient implementation. :param user: the user - :param methods: whether filter readable or writable + :param method: the method to filter on :param session: the session """ dag_ids = {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} - return self.filter_permitted_dag_ids(dag_ids=dag_ids, methods=methods, user=user) + return self.filter_permitted_dag_ids(dag_ids=dag_ids, method=method, user=user) def filter_permitted_dag_ids( self, *, dag_ids: set[str], user: T, - methods: Container[ResourceMethod] | None = None, + method: ResourceMethod = "GET", ) -> set[str]: """ Filter readable or writable DAGs for user. :param dag_ids: the list of DAG ids :param user: the user - :param methods: whether filter readable or writable + :param method: the method to filter on """ - if not methods: - methods = ["PUT", "GET"] - if ("GET" in methods and self.is_authorized_dag(method="GET", user=user)) or ( - "PUT" in methods and self.is_authorized_dag(method="PUT", user=user) - ): - # If user is authorized to read/edit all DAGs, return all DAGs - return dag_ids + def _is_permitted_dag_id(method: ResourceMethod, dag_id: str): + return self.is_authorized_dag(method=method, details=DagDetails(id=dag_id), user=user) - def _is_permitted_dag_id(method: ResourceMethod, methods: Container[ResourceMethod], dag_id: str): - return method in methods and self.is_authorized_dag( - method=method, details=DagDetails(id=dag_id), user=user - ) - - return { - dag_id - for dag_id in dag_ids - if _is_permitted_dag_id("GET", methods, dag_id) or _is_permitted_dag_id("PUT", methods, dag_id) - } + return {dag_id for dag_id in dag_ids if _is_permitted_dag_id(method, dag_id)} @staticmethod def get_cli_commands() -> list[CLICommand]: diff --git a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py index 6c992438a96d8..04b547b297925 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py @@ -18,7 +18,7 @@ import argparse from collections import defaultdict -from collections.abc import Container, Sequence +from collections.abc import Sequence from functools import cached_property from typing import TYPE_CHECKING, Any, cast @@ -283,23 +283,18 @@ def filter_permitted_dag_ids( *, dag_ids: set[str], user: AwsAuthManagerUser, - methods: Container[ResourceMethod] | None = None, + method: ResourceMethod = "GET", ): - if not methods: - methods = ["PUT", "GET"] - requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict) requests_list: list[IsAuthorizedRequest] = [] for dag_id in dag_ids: - for method in ["GET", "PUT"]: - if method in methods: - request: IsAuthorizedRequest = { - "method": cast("ResourceMethod", method), - "entity_type": AvpEntities.DAG, - "entity_id": dag_id, - } - requests[dag_id][cast("ResourceMethod", method)] = request - requests_list.append(request) + request: IsAuthorizedRequest = { + "method": method, + "entity_type": AvpEntities.DAG, + "entity_id": dag_id, + } + requests[dag_id][method] = request + requests_list.append(request) batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results( requests=requests_list, user=user @@ -311,16 +306,7 @@ def _has_access_to_dag(request: IsAuthorizedRequest): ) return result["decision"] == "ALLOW" - return { - dag_id - for dag_id in dag_ids - if ( - "GET" in methods - and _has_access_to_dag(requests[dag_id]["GET"]) - or "PUT" in methods - and _has_access_to_dag(requests[dag_id]["PUT"]) - ) - } + return {dag_id for dag_id in dag_ids if _has_access_to_dag(requests[dag_id][method])} def get_url_login(self, **kwargs) -> str: return f"{self.apiserver_endpoint}/auth/login" diff --git a/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py b/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py index be3485d40c179..45935e4b68622 100644 --- a/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py +++ b/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py @@ -445,26 +445,62 @@ def test_batch_is_authorized_variable( assert result @pytest.mark.parametrize( - "methods, user", + "method, user, expected_result", [ - (None, AwsAuthManagerUser(user_id="test_user_id", groups=[])), - (["PUT", "GET"], AwsAuthManagerUser(user_id="test_user_id", groups=[])), + ("GET", AwsAuthManagerUser(user_id="test_user_id1", groups=[]), {"dag_1"}), + ("PUT", AwsAuthManagerUser(user_id="test_user_id1", groups=[]), set()), + ("GET", AwsAuthManagerUser(user_id="test_user_id2", groups=[]), set()), + ("PUT", AwsAuthManagerUser(user_id="test_user_id2", groups=[]), {"dag_2"}), ], ) - def test_filter_permitted_dag_ids(self, methods, user, auth_manager, test_user): + def test_filter_permitted_dag_ids(self, method, user, auth_manager, test_user, expected_result): dag_ids = {"dag_1", "dag_2"} + # test_user_id1 has GET permissions on dag_1 + # test_user_id2 has PUT permissions on dag_2 batch_is_authorized_output = [ { "request": { - "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, + "principal": {"entityType": "Airflow::User", "entityId": "test_user_id1"}, "action": {"actionType": "Airflow::Action", "actionId": "Dag.GET"}, "resource": {"entityType": "Airflow::Dag", "entityId": "dag_1"}, }, + "decision": "ALLOW", + }, + { + "request": { + "principal": {"entityType": "Airflow::User", "entityId": "test_user_id1"}, + "action": {"actionType": "Airflow::Action", "actionId": "Dag.PUT"}, + "resource": {"entityType": "Airflow::Dag", "entityId": "dag_1"}, + }, "decision": "DENY", }, { "request": { - "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, + "principal": {"entityType": "Airflow::User", "entityId": "test_user_id1"}, + "action": {"actionType": "Airflow::Action", "actionId": "Dag.GET"}, + "resource": {"entityType": "Airflow::Dag", "entityId": "dag_2"}, + }, + "decision": "DENY", + }, + { + "request": { + "principal": {"entityType": "Airflow::User", "entityId": "test_user_id1"}, + "action": {"actionType": "Airflow::Action", "actionId": "Dag.PUT"}, + "resource": {"entityType": "Airflow::Dag", "entityId": "dag_2"}, + }, + "decision": "DENY", + }, + { + "request": { + "principal": {"entityType": "Airflow::User", "entityId": "test_user_id2"}, + "action": {"actionType": "Airflow::Action", "actionId": "Dag.GET"}, + "resource": {"entityType": "Airflow::Dag", "entityId": "dag_1"}, + }, + "decision": "DENY", + }, + { + "request": { + "principal": {"entityType": "Airflow::User", "entityId": "test_user_id2"}, "action": {"actionType": "Airflow::Action", "actionId": "Dag.PUT"}, "resource": {"entityType": "Airflow::Dag", "entityId": "dag_1"}, }, @@ -472,7 +508,7 @@ def test_filter_permitted_dag_ids(self, methods, user, auth_manager, test_user): }, { "request": { - "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, + "principal": {"entityType": "Airflow::User", "entityId": "test_user_id2"}, "action": {"actionType": "Airflow::Action", "actionId": "Dag.GET"}, "resource": {"entityType": "Airflow::Dag", "entityId": "dag_2"}, }, @@ -480,7 +516,7 @@ def test_filter_permitted_dag_ids(self, methods, user, auth_manager, test_user): }, { "request": { - "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, + "principal": {"entityType": "Airflow::User", "entityId": "test_user_id2"}, "action": {"actionType": "Airflow::Action", "actionId": "Dag.PUT"}, "resource": {"entityType": "Airflow::Dag", "entityId": "dag_2"}, }, @@ -493,12 +529,12 @@ def test_filter_permitted_dag_ids(self, methods, user, auth_manager, test_user): result = auth_manager.filter_permitted_dag_ids( dag_ids=dag_ids, - methods=methods, + method=method, user=user, ) auth_manager.avp_facade.get_batch_is_authorized_results.assert_called() - assert result == {"dag_2"} + assert result == expected_result def test_get_url_login(self, auth_manager): result = auth_manager.get_url_login() diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py index b416d31e2125f..eb3ef724ab666 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -18,7 +18,6 @@ from __future__ import annotations import argparse -from collections.abc import Container from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING, Any @@ -58,6 +57,7 @@ USERS_COMMANDS, ) from airflow.providers.fab.auth_manager.models import Permission, Role, User +from airflow.providers.fab.auth_manager.models.anonymous_user import AnonymousUser from airflow.providers.fab.www.app import create_app from airflow.providers.fab.www.constants import SWAGGER_BUNDLE, SWAGGER_ENABLED from airflow.providers.fab.www.extensions.init_views import _CustomErrorRequestBodyValidator, _LazyResolver @@ -355,30 +355,24 @@ def get_permitted_dag_ids( self, *, user: User, - methods: Container[ResourceMethod] | None = None, + method: ResourceMethod = "GET", session: Session = NEW_SESSION, ) -> set[str]: - if not methods: - methods = ["PUT", "GET"] - - if not self.is_logged_in(): - roles = user.roles - else: - if ("GET" in methods and self.is_authorized_dag(method="GET", user=user)) or ( - "PUT" in methods and self.is_authorized_dag(method="PUT", user=user) - ): - # If user is authorized to read/edit all DAGs, return all DAGs - return {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} - user_query = session.scalar( - select(User) - .options( - joinedload(User.roles) - .subqueryload(Role.permissions) - .options(joinedload(Permission.action), joinedload(Permission.resource)) - ) - .where(User.id == user.id) + if self._is_authorized(method=method, resource_type=RESOURCE_DAG, user=user): + # If user is authorized to access all DAGs, return all DAGs + return {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} + if isinstance(user, AnonymousUser): + return set() + user_query = session.scalar( + select(User) + .options( + joinedload(User.roles) + .subqueryload(Role.permissions) + .options(joinedload(Permission.action), joinedload(Permission.resource)) ) - roles = user_query.roles + .where(User.id == user.id) + ) + roles = user_query.roles map_fab_action_name_to_method_name = get_method_from_fab_action_map() resources = set() @@ -387,7 +381,7 @@ def get_permitted_dag_ids( action = permission.action.name if ( action in map_fab_action_name_to_method_name - and map_fab_action_name_to_method_name[action] in methods + and map_fab_action_name_to_method_name[action] == method ): resource = permission.resource.name if resource == permissions.RESOURCE_DAG: diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py index adbdfe14397d9..ef36355d2ad9e 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py @@ -973,12 +973,12 @@ def create_db(self): @staticmethod def get_readable_dag_ids(user=None) -> set[str]: """Get the DAG IDs readable by authenticated user.""" - return get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=user) + return get_auth_manager().get_permitted_dag_ids(user=user) @staticmethod def get_editable_dag_ids(user=None) -> set[str]: """Get the DAG IDs editable by authenticated user.""" - return get_auth_manager().get_permitted_dag_ids(methods=["PUT"], user=user) + return get_auth_manager().get_permitted_dag_ids(method="PUT", user=user) def can_access_some_dags(self, action: str, dag_id: str | None = None) -> bool: """Check if user has read or write access to some dags.""" diff --git a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py index 399937e63a515..45f200e09e53c 100644 --- a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py +++ b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py @@ -27,7 +27,8 @@ from airflow.exceptions import AirflowConfigException, AirflowException from airflow.providers.fab.www.extensions.init_appbuilder import init_appbuilder -from unit.fab.auth_manager.api_endpoints.api_connexion_utils import create_user +from airflow.providers.standard.operators.empty import EmptyOperator +from unit.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user try: from airflow.auth.managers.models.resource_details import AccessView, DagAccessEntity, DagDetails @@ -449,6 +450,83 @@ def test_is_authorized_custom_view( result = auth_manager.is_authorized_custom_view(method=method, resource_name=resource_name, user=user) assert result == expected_result + @pytest.mark.parametrize( + "method, user_permissions, expected_results", + [ + # Scenario 1 + # With global read permissions on Dags + ( + "GET", + [(ACTION_CAN_READ, RESOURCE_DAG)], + {"test_dag1", "test_dag2"}, + ), + # Scenario 2 + # With global edit permissions on Dags + ( + "PUT", + [(ACTION_CAN_EDIT, RESOURCE_DAG)], + {"test_dag1", "test_dag2"}, + ), + # Scenario 3 + # With DAG-specific permissions + ( + "GET", + [(ACTION_CAN_READ, "DAG:test_dag1")], + {"test_dag1"}, + ), + # Scenario 4 + # With no permissions + ( + "GET", + [], + set(), + ), + # Scenario 5 + # With read permissions but edit is requested + ( + "PUT", + [(ACTION_CAN_READ, RESOURCE_DAG)], + set(), + ), + # Scenario 7 + # With read permissions but edit is requested + ( + "PUT", + [(ACTION_CAN_READ, "DAG:test_dag1")], + set(), + ), + # Scenario 8 + # With DAG-specific permissions + ( + "PUT", + [(ACTION_CAN_EDIT, "DAG:test_dag1"), (ACTION_CAN_EDIT, "DAG:test_dag2")], + {"test_dag1", "test_dag2"}, + ), + ], + ) + def test_get_permitted_dag_ids( + self, method, user_permissions, expected_results, auth_manager_with_appbuilder, dag_maker, flask_app + ): + with dag_maker("test_dag1"): + EmptyOperator(task_id="task1") + with dag_maker("test_dag2"): + EmptyOperator(task_id="task1") + + auth_manager_with_appbuilder.security_manager.sync_perm_for_dag("test_dag1") + auth_manager_with_appbuilder.security_manager.sync_perm_for_dag("test_dag2") + + user = create_user( + flask_app, + username="username", + role_name="test", + permissions=user_permissions, + ) + + results = auth_manager_with_appbuilder.get_permitted_dag_ids(user=user, method=method) + assert results == expected_results + + delete_user(flask_app, "username") + @pytest.mark.db_test def test_security_manager_return_fab_security_manager_override(self, auth_manager_with_appbuilder): assert isinstance(auth_manager_with_appbuilder.security_manager, FabAirflowSecurityManagerOverride) diff --git a/providers/fab/tests/unit/fab/auth_manager/test_security.py b/providers/fab/tests/unit/fab/auth_manager/test_security.py index 4ceca5cb3ac23..a2aa51e329165 100644 --- a/providers/fab/tests/unit/fab/auth_manager/test_security.py +++ b/providers/fab/tests/unit/fab/auth_manager/test_security.py @@ -544,7 +544,7 @@ def test_dont_get_inaccessible_dag_ids_for_dag_resource_permission( dag_id, access_control={role_name: permission_action} ) - assert get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=user) == set() + assert get_auth_manager().get_permitted_dag_ids(user=user) == set() def test_has_access(security_manager): diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py index c228ad9158443..a7bb0322d015e 100644 --- a/tests/auth/managers/test_base_auth_manager.py +++ b/tests/auth/managers/test_base_auth_manager.py @@ -258,34 +258,23 @@ def test_batch_is_authorized_variable( assert result == expected @pytest.mark.parametrize( - "access_all, access_per_dag, dag_ids, expected", + "access_per_dag, dag_ids, expected", [ - # Access to all dags - ( - True, - {}, - ["dag1", "dag2"], - {"dag1", "dag2"}, - ), # No access to any dag ( - False, {}, ["dag1", "dag2"], set(), ), # Access to specific dags ( - False, {"dag1": True}, ["dag1", "dag2"], {"dag1"}, ), ], ) - def test_get_permitted_dag_ids( - self, auth_manager, access_all: bool, access_per_dag: dict, dag_ids: list, expected: set - ): + def test_get_permitted_dag_ids(self, auth_manager, access_per_dag: dict, dag_ids: list, expected: set): def side_effect_func( *, method: ResourceMethod, @@ -294,7 +283,7 @@ def side_effect_func( user: BaseAuthManagerUserTest | None = None, ): if not details: - return access_all + return False else: return access_per_dag.get(details.id, False)