From 6845d8bc67f76d1c285a8b5890c5579e81dcba44 Mon Sep 17 00:00:00 2001 From: vincbeck Date: Mon, 22 Apr 2024 11:22:43 -0400 Subject: [PATCH] Update `is_authorized_custom_view` from auth manager to handle custom actions --- airflow/auth/managers/base_auth_manager.py | 7 +++++-- .../amazon/aws/auth_manager/avp/entities.py | 2 +- .../providers/amazon/aws/auth_manager/avp/facade.py | 7 +++++-- .../amazon/aws/auth_manager/aws_auth_manager.py | 2 +- .../providers/fab/auth_manager/fab_auth_manager.py | 7 +++++-- tests/auth/managers/test_base_auth_manager.py | 2 +- .../fab/auth_manager/test_fab_auth_manager.py | 13 ++++++++++++- 7 files changed, 30 insertions(+), 10 deletions(-) diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index 4d5c249235a69..7bb4e92889e5c 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -237,7 +237,7 @@ def is_authorized_view( @abstractmethod def is_authorized_custom_view( - self, *, method: ResourceMethod, resource_name: str, user: BaseUser | None = None + self, *, method: ResourceMethod | str, resource_name: str, user: BaseUser | None = None ): """ Return whether the user is authorized to perform a given action on a custom view. @@ -246,7 +246,10 @@ def is_authorized_custom_view( the auth manager is used as part of the environment. It can also be a view defined as part of a plugin defined by a user. - :param method: the method to perform + :param method: the method to perform. + The method can also be a string if the action has been defined in a plugin. + In that case, the action can be anything (e.g. can_do). + See https://github.com/apache/airflow/issues/39144 :param resource_name: the name of the resource :param user: the user to perform the action on. If not provided (or None), it uses the current user """ diff --git a/airflow/providers/amazon/aws/auth_manager/avp/entities.py b/airflow/providers/amazon/aws/auth_manager/avp/entities.py index f2c63767299be..8c2e8855b877d 100644 --- a/airflow/providers/amazon/aws/auth_manager/avp/entities.py +++ b/airflow/providers/amazon/aws/auth_manager/avp/entities.py @@ -55,7 +55,7 @@ def get_entity_type(resource_type: AvpEntities) -> str: return AVP_PREFIX_ENTITIES + resource_type.value -def get_action_id(resource_type: AvpEntities, method: ResourceMethod): +def get_action_id(resource_type: AvpEntities, method: ResourceMethod | str): """ Return action id. diff --git a/airflow/providers/amazon/aws/auth_manager/avp/facade.py b/airflow/providers/amazon/aws/auth_manager/avp/facade.py index 010531155ede8..4bb9515004cf6 100644 --- a/airflow/providers/amazon/aws/auth_manager/avp/facade.py +++ b/airflow/providers/amazon/aws/auth_manager/avp/facade.py @@ -75,7 +75,7 @@ def avp_policy_store_id(self): def is_authorized( self, *, - method: ResourceMethod, + method: ResourceMethod | str, entity_type: AvpEntities, user: AwsAuthManagerUser | None, entity_id: str | None = None, @@ -86,7 +86,10 @@ def is_authorized( Check whether the user has permissions to access given resource. - :param method: the method to perform + :param method: the method to perform. + The method can also be a string if the action has been defined in a plugin. + In that case, the action can be anything (e.g. can_do). + See https://github.com/apache/airflow/issues/39144 :param entity_type: the entity type the user accesses :param user: the user :param entity_id: the entity ID the user accesses. If not provided, all entities of the type will be diff --git a/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py index 57b9f9ea0c312..f94e4de691d97 100644 --- a/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +++ b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py @@ -197,7 +197,7 @@ def is_authorized_view( ) def is_authorized_custom_view( - self, *, method: ResourceMethod, resource_name: str, user: BaseUser | None = None + self, *, method: ResourceMethod | str, resource_name: str, user: BaseUser | None = None ): return self.avp_facade.is_authorized( method=method, diff --git a/airflow/providers/fab/auth_manager/fab_auth_manager.py b/airflow/providers/fab/auth_manager/fab_auth_manager.py index d01b3526bf204..547bc626bbad5 100644 --- a/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -268,11 +268,14 @@ def is_authorized_view(self, *, access_view: AccessView, user: BaseUser | None = ) def is_authorized_custom_view( - self, *, method: ResourceMethod, resource_name: str, user: BaseUser | None = None + self, *, method: ResourceMethod | str, resource_name: str, user: BaseUser | None = None ): if not user: user = self.get_user() - fab_action_name = get_fab_action_from_method_map()[method] + if method in get_fab_action_from_method_map(): + fab_action_name = get_fab_action_from_method_map()[method] + else: + fab_action_name = method return (fab_action_name, resource_name) in self._get_user_permissions(user) @provide_session diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py index 04191c4838c8a..64d33f60659ad 100644 --- a/tests/auth/managers/test_base_auth_manager.py +++ b/tests/auth/managers/test_base_auth_manager.py @@ -95,7 +95,7 @@ def is_authorized_view(self, *, access_view: AccessView, user: BaseUser | None = raise NotImplementedError() def is_authorized_custom_view( - self, *, method: ResourceMethod, resource_name: str, user: BaseUser | None = None + self, *, method: ResourceMethod | str, resource_name: str, user: BaseUser | None = None ): raise NotImplementedError() diff --git a/tests/providers/fab/auth_manager/test_fab_auth_manager.py b/tests/providers/fab/auth_manager/test_fab_auth_manager.py index 72e63983b512e..30ef64b281aac 100644 --- a/tests/providers/fab/auth_manager/test_fab_auth_manager.py +++ b/tests/providers/fab/auth_manager/test_fab_auth_manager.py @@ -392,10 +392,21 @@ def test_is_authorized_view(self, access_view, user_permissions, expected_result [(ACTION_CAN_READ, "custom_resource2")], False, ), + ( + "DUMMY", + "custom_resource", + [("DUMMY", "custom_resource")], + True, + ), ], ) def test_is_authorized_custom_view( - self, method: ResourceMethod, resource_name: str, user_permissions, expected_result, auth_manager + self, + method: ResourceMethod | str, + resource_name: str, + user_permissions, + expected_result, + auth_manager, ): user = Mock() user.perms = user_permissions