diff --git a/airflow-core/docs/core-concepts/auth-manager/index.rst b/airflow-core/docs/core-concepts/auth-manager/index.rst index a1959da497fa1..b0c99abf4bf54 100644 --- a/airflow-core/docs/core-concepts/auth-manager/index.rst +++ b/airflow-core/docs/core-concepts/auth-manager/index.rst @@ -176,10 +176,7 @@ Optional methods recommended to override for optimization The following methods aren't required to override to have a functional Airflow auth manager. However, it is recommended to override these to make your auth manager faster (and potentially less costly): -* ``batch_is_authorized_connection``: Batch version of ``is_authorized_connection``. If not overridden, it will call ``is_authorized_connection`` for every single item. * ``batch_is_authorized_dag``: Batch version of ``is_authorized_dag``. If not overridden, it will call ``is_authorized_dag`` for every single item. -* ``batch_is_authorized_pool``: Batch version of ``is_authorized_pool``. If not overridden, it will call ``is_authorized_pool`` for every single item. -* ``batch_is_authorized_variable``: Batch version of ``is_authorized_variable``. If not overridden, it will call ``is_authorized_variable`` for every single item. * ``get_authorized_dag_ids``: Return the list of DAG IDs the user has access to. If not overridden, it will call ``is_authorized_dag`` for every single DAG available in the environment. CLI diff --git a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py index cd766652e2d96..108e4a1289350 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py +++ b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py @@ -47,10 +47,7 @@ from sqlalchemy.orm import Session from airflow.api_fastapi.auth.managers.models.batch_apis import ( - IsAuthorizedConnectionRequest, IsAuthorizedDagRequest, - IsAuthorizedPoolRequest, - IsAuthorizedVariableRequest, ) from airflow.api_fastapi.auth.managers.models.resource_details import ( AccessView, @@ -307,27 +304,6 @@ def filter_authorized_menu_items(self, menu_items: list[MenuItem], *, user: T) - :param user: the user """ - def batch_is_authorized_connection( - self, - requests: Sequence[IsAuthorizedConnectionRequest], - *, - user: T, - ) -> bool: - """ - Batch version of ``is_authorized_connection``. - - By default, calls individually the ``is_authorized_connection`` API on each item in the list of - requests, which can lead to some poor performance. It is recommended to override this method in the auth - manager implementation to provide a more efficient implementation. - - :param requests: a list of requests containing the parameters for ``is_authorized_connection`` - :param user: the user to performing the action - """ - return all( - self.is_authorized_connection(method=request["method"], details=request.get("details"), user=user) - for request in requests - ) - def batch_is_authorized_dag( self, requests: Sequence[IsAuthorizedDagRequest], @@ -354,48 +330,6 @@ def batch_is_authorized_dag( for request in requests ) - def batch_is_authorized_pool( - self, - requests: Sequence[IsAuthorizedPoolRequest], - *, - user: T, - ) -> bool: - """ - Batch version of ``is_authorized_pool``. - - By default, calls individually the ``is_authorized_pool`` API on each item in the list of - requests. Can lead to some poor performance. It is recommended to override this method in the auth - manager implementation to provide a more efficient implementation. - - :param requests: a list of requests containing the parameters for ``is_authorized_pool`` - :param user: the user to performing the action - """ - return all( - self.is_authorized_pool(method=request["method"], details=request.get("details"), user=user) - for request in requests - ) - - def batch_is_authorized_variable( - self, - requests: Sequence[IsAuthorizedVariableRequest], - *, - user: T, - ) -> bool: - """ - Batch version of ``is_authorized_variable``. - - By default, calls individually the ``is_authorized_variable`` API on each item in the list of - requests. Can lead to some poor performance. It is recommended to override this method in the auth - manager implementation to provide a more efficient implementation. - - :param requests: a list of requests containing the parameters for ``is_authorized_variable`` - :param user: the user to performing the action - """ - return all( - self.is_authorized_variable(method=request["method"], details=request.get("details"), user=user) - for request in requests - ) - @provide_session def get_authorized_dag_ids( self, diff --git a/airflow-core/src/airflow/api_fastapi/auth/managers/models/batch_apis.py b/airflow-core/src/airflow/api_fastapi/auth/managers/models/batch_apis.py index 2fe11b659af6e..5acdd3edee5f2 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/managers/models/batch_apis.py +++ b/airflow-core/src/airflow/api_fastapi/auth/managers/models/batch_apis.py @@ -22,38 +22,14 @@ if TYPE_CHECKING: from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod from airflow.api_fastapi.auth.managers.models.resource_details import ( - ConnectionDetails, DagAccessEntity, DagDetails, - PoolDetails, - VariableDetails, ) -class IsAuthorizedConnectionRequest(TypedDict, total=False): - """Represent the parameters of ``is_authorized_connection`` API in the auth manager.""" - - method: ResourceMethod - details: ConnectionDetails | None - - class IsAuthorizedDagRequest(TypedDict, total=False): """Represent the parameters of ``is_authorized_dag`` API in the auth manager.""" method: ResourceMethod access_entity: DagAccessEntity | None details: DagDetails | None - - -class IsAuthorizedPoolRequest(TypedDict, total=False): - """Represent the parameters of ``is_authorized_pool`` API in the auth manager.""" - - method: ResourceMethod - details: PoolDetails | None - - -class IsAuthorizedVariableRequest(TypedDict, total=False): - """Represent the parameters of ``is_authorized_variable`` API in the auth manager.""" - - method: ResourceMethod - details: VariableDetails | None diff --git a/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py b/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py index aad69253ef289..a0fc0ab03431f 100644 --- a/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py +++ b/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py @@ -261,70 +261,6 @@ def test_batch_is_authorized_dag(self, mock_is_authorized_dag, auth_manager, ret ) assert result == expected - @pytest.mark.parametrize( - "return_values, expected", - [ - ([False, False], False), - ([True, False], False), - ([True, True], True), - ], - ) - @patch.object(EmptyAuthManager, "is_authorized_connection") - def test_batch_is_authorized_connection( - self, mock_is_authorized_connection, auth_manager, return_values, expected - ): - mock_is_authorized_connection.side_effect = return_values - result = auth_manager.batch_is_authorized_connection( - [ - {"method": "GET", "details": ConnectionDetails(conn_id="conn1")}, - {"method": "GET", "details": ConnectionDetails(conn_id="conn2")}, - ], - user=Mock(), - ) - assert result == expected - - @pytest.mark.parametrize( - "return_values, expected", - [ - ([False, False], False), - ([True, False], False), - ([True, True], True), - ], - ) - @patch.object(EmptyAuthManager, "is_authorized_pool") - def test_batch_is_authorized_pool(self, mock_is_authorized_pool, auth_manager, return_values, expected): - mock_is_authorized_pool.side_effect = return_values - result = auth_manager.batch_is_authorized_pool( - [ - {"method": "GET", "details": PoolDetails(name="pool1")}, - {"method": "GET", "details": PoolDetails(name="pool2")}, - ], - user=Mock(), - ) - assert result == expected - - @pytest.mark.parametrize( - "return_values, expected", - [ - ([False, False], False), - ([True, False], False), - ([True, True], True), - ], - ) - @patch.object(EmptyAuthManager, "is_authorized_variable") - def test_batch_is_authorized_variable( - self, mock_is_authorized_variable, auth_manager, return_values, expected - ): - mock_is_authorized_variable.side_effect = return_values - result = auth_manager.batch_is_authorized_variable( - [ - {"method": "GET", "details": VariableDetails(key="var1")}, - {"method": "GET", "details": VariableDetails(key="var2")}, - ], - user=Mock(), - ) - assert result == expected - @pytest.mark.parametrize( "access_per_dag, dag_ids, expected", [ diff --git a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py index d427f210e11e7..387d968ec157f 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py @@ -44,10 +44,7 @@ if TYPE_CHECKING: from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod from airflow.api_fastapi.auth.managers.models.batch_apis import ( - IsAuthorizedConnectionRequest, IsAuthorizedDagRequest, - IsAuthorizedPoolRequest, - IsAuthorizedVariableRequest, ) from airflow.api_fastapi.auth.managers.models.resource_details import ( AccessView, @@ -247,24 +244,6 @@ def _has_access_to_menu_item(request: IsAuthorizedRequest): return [menu_item for menu_item in menu_items if _has_access_to_menu_item(requests[menu_item.value])] - def batch_is_authorized_connection( - self, - requests: Sequence[IsAuthorizedConnectionRequest], - *, - user: AwsAuthManagerUser, - ) -> bool: - facade_requests: Sequence[IsAuthorizedRequest] = [ - { - "method": request["method"], - "entity_type": AvpEntities.CONNECTION, - "entity_id": cast("ConnectionDetails", request["details"]).conn_id - if request.get("details") - else None, - } - for request in requests - ] - return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user) - def batch_is_authorized_dag( self, requests: Sequence[IsAuthorizedDagRequest], @@ -288,40 +267,6 @@ def batch_is_authorized_dag( ] return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user) - def batch_is_authorized_pool( - self, - requests: Sequence[IsAuthorizedPoolRequest], - *, - user: AwsAuthManagerUser, - ) -> bool: - facade_requests: Sequence[IsAuthorizedRequest] = [ - { - "method": request["method"], - "entity_type": AvpEntities.POOL, - "entity_id": cast("PoolDetails", request["details"]).name if request.get("details") else None, - } - for request in requests - ] - return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user) - - def batch_is_authorized_variable( - self, - requests: Sequence[IsAuthorizedVariableRequest], - *, - user: AwsAuthManagerUser, - ) -> bool: - facade_requests: Sequence[IsAuthorizedRequest] = [ - { - "method": request["method"], - "entity_type": AvpEntities.VARIABLE, - "entity_id": cast("VariableDetails", request["details"]).key - if request.get("details") - else None, - } - for request in requests - ] - return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user) - def filter_authorized_dag_ids( self, *, diff --git a/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py b/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py index 055e50ac9a064..eab473ea9a5be 100644 --- a/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py +++ b/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py @@ -368,37 +368,6 @@ def test_is_authorized_view(self, mock_avp_facade, access_view, user, expected_u ) assert result - @patch.object(AwsAuthManager, "avp_facade") - def test_batch_is_authorized_connection( - self, - mock_avp_facade, - auth_manager, - ): - batch_is_authorized = Mock(return_value=True) - mock_avp_facade.batch_is_authorized = batch_is_authorized - - result = auth_manager.batch_is_authorized_connection( - requests=[{"method": "GET"}, {"method": "GET", "details": ConnectionDetails(conn_id="conn_id")}], - user=mock, - ) - - batch_is_authorized.assert_called_once_with( - requests=[ - { - "method": "GET", - "entity_type": AvpEntities.CONNECTION, - "entity_id": None, - }, - { - "method": "GET", - "entity_type": AvpEntities.CONNECTION, - "entity_id": "conn_id", - }, - ], - user=ANY, - ) - assert result - def test_filter_authorized_menu_items(self, auth_manager): batch_is_authorized_output = [ { @@ -541,68 +510,6 @@ def test_batch_is_authorized_dag( ) assert result - @patch.object(AwsAuthManager, "avp_facade") - def test_batch_is_authorized_pool( - self, - mock_avp_facade, - auth_manager, - ): - batch_is_authorized = Mock(return_value=True) - mock_avp_facade.batch_is_authorized = batch_is_authorized - - result = auth_manager.batch_is_authorized_pool( - requests=[{"method": "GET"}, {"method": "GET", "details": PoolDetails(name="pool1")}], - user=mock, - ) - - batch_is_authorized.assert_called_once_with( - requests=[ - { - "method": "GET", - "entity_type": AvpEntities.POOL, - "entity_id": None, - }, - { - "method": "GET", - "entity_type": AvpEntities.POOL, - "entity_id": "pool1", - }, - ], - user=ANY, - ) - assert result - - @patch.object(AwsAuthManager, "avp_facade") - def test_batch_is_authorized_variable( - self, - mock_avp_facade, - auth_manager, - ): - batch_is_authorized = Mock(return_value=True) - mock_avp_facade.batch_is_authorized = batch_is_authorized - - result = auth_manager.batch_is_authorized_variable( - requests=[{"method": "GET"}, {"method": "GET", "details": VariableDetails(key="var1")}], - user=mock, - ) - - batch_is_authorized.assert_called_once_with( - requests=[ - { - "method": "GET", - "entity_type": AvpEntities.VARIABLE, - "entity_id": None, - }, - { - "method": "GET", - "entity_type": AvpEntities.VARIABLE, - "entity_id": "var1", - }, - ], - user=ANY, - ) - assert result - @pytest.mark.parametrize( "method, user, expected_result", [ diff --git a/providers/fab/src/airflow/providers/fab/www/auth.py b/providers/fab/src/airflow/providers/fab/www/auth.py index 5cd35e4600dbd..e25f77c87589b 100644 --- a/providers/fab/src/airflow/providers/fab/www/auth.py +++ b/providers/fab/src/airflow/providers/fab/www/auth.py @@ -46,10 +46,7 @@ if TYPE_CHECKING: from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod from airflow.api_fastapi.auth.managers.models.batch_apis import ( - IsAuthorizedConnectionRequest, IsAuthorizedDagRequest, - IsAuthorizedPoolRequest, - IsAuthorizedVariableRequest, ) from airflow.models import DagRun, Pool, TaskInstance, Variable from airflow.models.connection import Connection @@ -170,15 +167,13 @@ def has_access_decorator(func: T): @wraps(func) def decorated(*args, **kwargs): connections: set[Connection] = set(args[1]) - requests: Sequence[IsAuthorizedConnectionRequest] = [ - { - "method": method, - "details": ConnectionDetails(conn_id=connection.conn_id), - } + is_authorized = all( + get_auth_manager().is_authorized_connection( + method=method, + details=ConnectionDetails(conn_id=connection.conn_id), + user=get_auth_manager().get_user(), + ) for connection in connections - ] - is_authorized = get_auth_manager().batch_is_authorized_connection( - requests, user=get_auth_manager().get_user() ) return _has_access( is_authorized=is_authorized, @@ -284,15 +279,11 @@ def has_access_decorator(func: T): @wraps(func) def decorated(*args, **kwargs): pools: set[Pool] = set(args[1]) - requests: Sequence[IsAuthorizedPoolRequest] = [ - { - "method": method, - "details": PoolDetails(name=pool.pool), - } + is_authorized = all( + get_auth_manager().is_authorized_pool( + method=method, details=PoolDetails(name=pool.pool), user=get_auth_manager().get_user() + ) for pool in pools - ] - is_authorized = get_auth_manager().batch_is_authorized_pool( - requests, user=get_auth_manager().get_user() ) return _has_access( is_authorized=is_authorized, @@ -310,23 +301,15 @@ def has_access_variable(method: ResourceMethod) -> Callable[[T], T]: def has_access_decorator(func: T): @wraps(func) def decorated(*args, **kwargs): - if len(args) == 1: - # No items provided - is_authorized = get_auth_manager().is_authorized_variable( - method=method, user=get_auth_manager().get_user() - ) - else: - variables: set[Variable] = set(args[1]) - requests: Sequence[IsAuthorizedVariableRequest] = [ - { - "method": method, - "details": VariableDetails(key=variable.key), - } - for variable in variables - ] - is_authorized = get_auth_manager().batch_is_authorized_variable( - requests, user=get_auth_manager().get_user() + variables: set[Variable] = set(args[1]) + is_authorized = all( + get_auth_manager().is_authorized_variable( + method=method, + details=VariableDetails(key=variable.key), + user=get_auth_manager().get_user(), ) + for variable in variables + ) return _has_access( is_authorized=is_authorized, func=func, diff --git a/providers/fab/tests/unit/fab/www/test_auth.py b/providers/fab/tests/unit/fab/www/test_auth.py index dfa7df6e4c43f..b20ee27bd34a5 100644 --- a/providers/fab/tests/unit/fab/www/test_auth.py +++ b/providers/fab/tests/unit/fab/www/test_auth.py @@ -135,11 +135,11 @@ def get_variable(): [ ( "has_access_connection", - "batch_is_authorized_connection", + "is_authorized_connection", "get_connection", ), - ("has_access_pool", "batch_is_authorized_pool", "get_pool"), - ("has_access_variable", "batch_is_authorized_variable", "get_variable"), + ("has_access_pool", "is_authorized_pool", "get_pool"), + ("has_access_variable", "is_authorized_variable", "get_variable"), ], ) class TestHasAccessWithDetails: diff --git a/providers/fab/www-hash.txt b/providers/fab/www-hash.txt index 9ac61b6cd0166..e84aa2d163c59 100644 --- a/providers/fab/www-hash.txt +++ b/providers/fab/www-hash.txt @@ -1 +1 @@ -bba05295e6d4ef8f0bfe766b77cef4d90d62e86f3a2162de32d6e94979b236c7 +06230f1bd1b77ee1be84feaaa61b26d7dcdb2f4665ec619e459b7d8b383f3ca0