diff --git a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py index 387d968ec157f..6f3e5851c2f34 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py @@ -44,7 +44,10 @@ if TYPE_CHECKING: from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod from airflow.api_fastapi.auth.managers.models.batch_apis import ( + IsAuthorizedConnectionRequest, IsAuthorizedDagRequest, + IsAuthorizedPoolRequest, + IsAuthorizedVariableRequest, ) from airflow.api_fastapi.auth.managers.models.resource_details import ( AccessView, @@ -244,6 +247,27 @@ def _has_access_to_menu_item(request: IsAuthorizedRequest): return [menu_item for menu_item in menu_items if _has_access_to_menu_item(requests[menu_item.value])] + def batch_is_authorized_connection( + self, + requests: Sequence[IsAuthorizedConnectionRequest], + *, + user: AwsAuthManagerUser, + ) -> bool: + facade_requests: Sequence[IsAuthorizedRequest] = [ + cast( + "IsAuthorizedRequest", + { + "method": request["method"], + "entity_type": AvpEntities.CONNECTION, + "entity_id": cast("ConnectionDetails", request["details"]).conn_id + if request.get("details") + else None, + }, + ) + for request in requests + ] + return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user) + def batch_is_authorized_dag( self, requests: Sequence[IsAuthorizedDagRequest], @@ -251,18 +275,65 @@ def batch_is_authorized_dag( user: AwsAuthManagerUser, ) -> bool: facade_requests: Sequence[IsAuthorizedRequest] = [ - { - "method": request["method"], - "entity_type": AvpEntities.DAG, - "entity_id": cast("DagDetails", request["details"]).id if request.get("details") else None, - "context": { - "dag_entity": { - "string": cast("DagAccessEntity", request["access_entity"]).value, - }, - } - if request.get("access_entity") - else None, - } + cast( + "IsAuthorizedRequest", + { + "method": request["method"], + "entity_type": AvpEntities.DAG, + "entity_id": cast("DagDetails", request["details"]).id + if request.get("details") + else None, + "context": { + "dag_entity": { + "string": cast("DagAccessEntity", request["access_entity"]).value, + }, + } + if request.get("access_entity") + else None, + }, + ) + for request in requests + ] + return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user) + + def batch_is_authorized_pool( + self, + requests: Sequence[IsAuthorizedPoolRequest], + *, + user: AwsAuthManagerUser, + ) -> bool: + facade_requests: Sequence[IsAuthorizedRequest] = [ + cast( + "IsAuthorizedRequest", + { + "method": request["method"], + "entity_type": AvpEntities.POOL, + "entity_id": cast("PoolDetails", request["details"]).name + if request.get("details") + else None, + }, + ) + for request in requests + ] + return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user) + + def batch_is_authorized_variable( + self, + requests: Sequence[IsAuthorizedVariableRequest], + *, + user: AwsAuthManagerUser, + ) -> bool: + facade_requests: Sequence[IsAuthorizedRequest] = [ + cast( + "IsAuthorizedRequest", + { + "method": request["method"], + "entity_type": AvpEntities.VARIABLE, + "entity_id": cast("VariableDetails", request["details"]).key + if request.get("details") + else None, + }, + ) for request in requests ] return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user) diff --git a/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py b/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py index eab473ea9a5be..70d5a31986ec4 100644 --- a/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py +++ b/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py @@ -439,6 +439,40 @@ def test_filter_authorized_menu_items(self, auth_manager): ) assert result == [MenuItem.VARIABLES, MenuItem.DAGS] + @patch.object(AwsAuthManager, "avp_facade") + def test_batch_is_authorized_connection( + self, + mock_avp_facade, + auth_manager, + ): + batch_is_authorized = Mock(return_value=True) + mock_avp_facade.batch_is_authorized = batch_is_authorized + + result = auth_manager.batch_is_authorized_connection( + requests=[ + {"method": "GET"}, + {"method": "PUT", "details": ConnectionDetails(conn_id="test")}, + ], + user=mock, + ) + + batch_is_authorized.assert_called_once_with( + requests=[ + { + "method": "GET", + "entity_type": AvpEntities.CONNECTION, + "entity_id": None, + }, + { + "method": "PUT", + "entity_type": AvpEntities.CONNECTION, + "entity_id": "test", + }, + ], + user=ANY, + ) + assert result + @patch.object(AwsAuthManager, "avp_facade") def test_batch_is_authorized_dag( self, @@ -510,6 +544,74 @@ def test_batch_is_authorized_dag( ) assert result + @patch.object(AwsAuthManager, "avp_facade") + def test_batch_is_authorized_pool( + self, + mock_avp_facade, + auth_manager, + ): + batch_is_authorized = Mock(return_value=True) + mock_avp_facade.batch_is_authorized = batch_is_authorized + + result = auth_manager.batch_is_authorized_pool( + requests=[ + {"method": "GET"}, + {"method": "PUT", "details": PoolDetails(name="test")}, + ], + user=mock, + ) + + batch_is_authorized.assert_called_once_with( + requests=[ + { + "method": "GET", + "entity_type": AvpEntities.POOL, + "entity_id": None, + }, + { + "method": "PUT", + "entity_type": AvpEntities.POOL, + "entity_id": "test", + }, + ], + user=ANY, + ) + assert result + + @patch.object(AwsAuthManager, "avp_facade") + def test_batch_is_authorized_variable( + self, + mock_avp_facade, + auth_manager, + ): + batch_is_authorized = Mock(return_value=True) + mock_avp_facade.batch_is_authorized = batch_is_authorized + + result = auth_manager.batch_is_authorized_variable( + requests=[ + {"method": "GET"}, + {"method": "PUT", "details": VariableDetails(key="test")}, + ], + user=mock, + ) + + batch_is_authorized.assert_called_once_with( + requests=[ + { + "method": "GET", + "entity_type": AvpEntities.VARIABLE, + "entity_id": None, + }, + { + "method": "PUT", + "entity_type": AvpEntities.VARIABLE, + "entity_id": "test", + }, + ], + user=ANY, + ) + assert result + @pytest.mark.parametrize( "method, user, expected_result", [