diff --git a/airflow-core/docs/core-concepts/auth-manager/index.rst b/airflow-core/docs/core-concepts/auth-manager/index.rst index fffc89fc4d572..13a20eeaa3cfa 100644 --- a/airflow-core/docs/core-concepts/auth-manager/index.rst +++ b/airflow-core/docs/core-concepts/auth-manager/index.rst @@ -178,14 +178,14 @@ Optional methods recommended to override for optimization The following methods aren't required to override to have a functional Airflow auth manager. However, it is recommended to override these to make your auth manager faster (and potentially less costly): -* ``batch_is_authorized_connection``: Batch version of ``is_authorized_connection``. If not overridden, it will call ``is_authorized_connection`` for every single item. -* ``batch_is_authorized_dag``: Batch version of ``is_authorized_dag``. If not overridden, it will call ``is_authorized_dag`` for every single item. -* ``batch_is_authorized_pool``: Batch version of ``is_authorized_pool``. If not overridden, it will call ``is_authorized_pool`` for every single item. -* ``batch_is_authorized_variable``: Batch version of ``is_authorized_variable``. If not overridden, it will call ``is_authorized_variable`` for every single item. -* ``get_authorized_dag_ids``: Return the list of Dag IDs the user has access to. If not overridden, it will call ``is_authorized_dag`` for every single Dag available in the environment. - - * Note: To filter the results of ``get_authorized_dag_ids``, it is recommended that you define the filtering logic in your ``filter_authorized_dag_ids`` method. For example, this may be useful if you rely on per-Dag access controls derived from one or more fields on a given Dag (e.g. Dag tags). - * This method requires an active session with the Airflow metadata database. As such, overriding the ``get_authorized_dag_ids`` method is an advanced use case, which should be considered carefully -- it is recommended you refer to the :doc:`../../database-erd-ref`. +* ``batch_is_authorized_connection``: Batch version of ``is_authorized_connection``. If not overridden, it calls ``is_authorized_connection`` for every single item. +* ``batch_is_authorized_dag``: Batch version of ``is_authorized_dag``. If not overridden, it calls ``is_authorized_dag`` for every single item. +* ``batch_is_authorized_pool``: Batch version of ``is_authorized_pool``. If not overridden, it calls ``is_authorized_pool`` for every single item. +* ``batch_is_authorized_variable``: Batch version of ``is_authorized_variable``. If not overridden, it calls ``is_authorized_variable`` for every single item. +* ``filter_authorized_connections``: Given a list of connection IDs (``conn_id``), return the list of connection IDs the user has access to. If not overridden, it calls ``is_authorized_connection`` for every single connection passed as parameter. +* ``filter_authorized_dag_ids``: Given a list of Dag IDs, return the list of Dag IDs the user has access to. If not overridden, it calls ``is_authorized_dag`` for every single Dag passes as parameter. +* ``filter_authorized_pools``: Given a list of pool names, return the list of pool names the user has access to. If not overridden, it calls ``is_authorized_pool`` for every single pool passed as parameter. +* ``filter_authorized_variables``: Given a list of variable keys, return the list of variable keys the user has access to. If not overridden, it calls ``is_authorized_variable`` for every single variable passed as parameter. CLI ^^^ diff --git a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py index d57dd3cdc39f1..456b29c0cf422 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py +++ b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py @@ -19,6 +19,7 @@ import logging from abc import ABCMeta, abstractmethod +from collections import defaultdict from functools import cache from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar @@ -26,7 +27,13 @@ from sqlalchemy import select from airflow.api_fastapi.auth.managers.models.base_user import BaseUser -from airflow.api_fastapi.auth.managers.models.resource_details import BackfillDetails, DagDetails +from airflow.api_fastapi.auth.managers.models.resource_details import ( + BackfillDetails, + ConnectionDetails, + DagDetails, + PoolDetails, + VariableDetails, +) from airflow.api_fastapi.auth.tokens import ( JWTGenerator, JWTValidator, @@ -35,7 +42,9 @@ ) from airflow.api_fastapi.common.types import ExtraMenuItem, MenuItem from airflow.configuration import conf -from airflow.models import DagModel +from airflow.models import Connection, DagModel, Pool, Variable +from airflow.models.dagbundle import DagBundleModel +from airflow.models.team import Team, dag_bundle_team_association_table from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session @@ -56,10 +65,7 @@ AssetAliasDetails, AssetDetails, ConfigurationDetails, - ConnectionDetails, DagAccessEntity, - PoolDetails, - VariableDetails, ) from airflow.cli.cli_config import CLICommand @@ -427,16 +433,34 @@ def get_authorized_dag_ids( """ Get DAGs the user has access to. - By default, reads all the DAGs and check individually if the user has permissions to access the DAG. - Can lead to some poor performance. It is recommended to override this method in the auth manager - implementation to provide a more efficient implementation. - :param user: the user :param method: the method to filter on :param session: the session """ - dag_ids = {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} - return self.filter_authorized_dag_ids(dag_ids=dag_ids, method=method, user=user) + stmt = ( + select(DagModel.dag_id, Team.name) + .join(DagBundleModel, DagModel.bundle_name == DagBundleModel.name) + .join( + dag_bundle_team_association_table, + DagBundleModel.name == dag_bundle_team_association_table.c.dag_bundle_name, + isouter=True, + ) + .join(Team, Team.id == dag_bundle_team_association_table.c.team_id, isouter=True) + ) + rows = session.execute(stmt).all() + dags_by_team: dict[str | None, set[str]] = defaultdict(set) + for dag_id, team_name in rows: + dags_by_team[team_name].add(dag_id) + + dag_ids: set[str] = set() + for team_name, team_dag_ids in dags_by_team.items(): + dag_ids.update( + self.filter_authorized_dag_ids( + dag_ids=team_dag_ids, user=user, method=method, team_name=team_name + ) + ) + + return dag_ids def filter_authorized_dag_ids( self, @@ -444,19 +468,208 @@ def filter_authorized_dag_ids( dag_ids: set[str], user: T, method: ResourceMethod = "GET", + team_name: str | None = None, ) -> set[str]: """ Filter DAGs the user has access to. - :param dag_ids: the list of DAG ids + By default, check individually if the user has permissions to access the DAG. + Can lead to some poor performance. It is recommended to override this method in the auth manager + implementation to provide a more efficient implementation. + + :param dag_ids: the set of DAG ids + :param user: the user + :param method: the method to filter on + :param team_name: the name of the team associated to the Dags if Airflow environment runs in + multi-team mode + """ + + def _is_authorized_dag_id(dag_id: str): + return self.is_authorized_dag( + method=method, details=DagDetails(id=dag_id, team_name=team_name), user=user + ) + + return {dag_id for dag_id in dag_ids if _is_authorized_dag_id(dag_id)} + + @provide_session + def get_authorized_connections( + self, + *, + user: T, + method: ResourceMethod = "GET", + session: Session = NEW_SESSION, + ) -> set[str]: + """ + Get connection ids (``conn_id``) the user has access to. + + :param user: the user + :param method: the method to filter on + :param session: the session + """ + stmt = select(Connection.conn_id, Team.name).join(Team, Connection.team_id == Team.id, isouter=True) + rows = session.execute(stmt).all() + connections_by_team: dict[str | None, set[str]] = defaultdict(set) + for conn_id, team_name in rows: + connections_by_team[team_name].add(conn_id) + + conn_ids: set[str] = set() + for team_name, team_conn_ids in connections_by_team.items(): + conn_ids.update( + self.filter_authorized_connections( + conn_ids=team_conn_ids, user=user, method=method, team_name=team_name + ) + ) + + return conn_ids + + def filter_authorized_connections( + self, + *, + conn_ids: set[str], + user: T, + method: ResourceMethod = "GET", + team_name: str | None = None, + ) -> set[str]: + """ + Filter connections the user has access to. + + By default, check individually if the user has permissions to access the connection. + Can lead to some poor performance. It is recommended to override this method in the auth manager + implementation to provide a more efficient implementation. + + :param conn_ids: the set of connection ids (``conn_id``) :param user: the user :param method: the method to filter on + :param team_name: the name of the team associated to the connections if Airflow environment runs in + multi-team mode + """ + + def _is_authorized_connection(conn_id: str): + return self.is_authorized_connection( + method=method, details=ConnectionDetails(conn_id=conn_id, team_name=team_name), user=user + ) + + return {conn_id for conn_id in conn_ids if _is_authorized_connection(conn_id)} + + @provide_session + def get_authorized_variables( + self, + *, + user: T, + method: ResourceMethod = "GET", + session: Session = NEW_SESSION, + ) -> set[str]: """ + Get variable keys the user has access to. - def _is_authorized_dag_id(method: ResourceMethod, dag_id: str): - return self.is_authorized_dag(method=method, details=DagDetails(id=dag_id), user=user) + :param user: the user + :param method: the method to filter on + :param session: the session + """ + stmt = select(Variable.key, Team.name).join(Team, Variable.team_id == Team.id, isouter=True) + rows = session.execute(stmt).all() + variables_by_team: dict[str | None, set[str]] = defaultdict(set) + for var_key, team_name in rows: + variables_by_team[team_name].add(var_key) + + var_keys: set[str] = set() + for team_name, team_var_keys in variables_by_team.items(): + var_keys.update( + self.filter_authorized_variables( + variable_keys=team_var_keys, user=user, method=method, team_name=team_name + ) + ) + + return var_keys + + def filter_authorized_variables( + self, + *, + variable_keys: set[str], + user: T, + method: ResourceMethod = "GET", + team_name: str | None = None, + ) -> set[str]: + """ + Filter variables the user has access to. + + By default, check individually if the user has permissions to access the variable. + Can lead to some poor performance. It is recommended to override this method in the auth manager + implementation to provide a more efficient implementation. + + :param variable_keys: the set of variable keys + :param user: the user + :param method: the method to filter on + :param team_name: the name of the team associated to the connections if Airflow environment runs in + multi-team mode + """ + + def _is_authorized_variable(var_key: str): + return self.is_authorized_variable( + method=method, details=VariableDetails(key=var_key, team_name=team_name), user=user + ) + + return {var_key for var_key in variable_keys if _is_authorized_variable(var_key)} + + @provide_session + def get_authorized_pools( + self, + *, + user: T, + method: ResourceMethod = "GET", + session: Session = NEW_SESSION, + ) -> set[str]: + """ + Get pools the user has access to. + + :param user: the user + :param method: the method to filter on + :param session: the session + """ + stmt = select(Pool.pool, Team.name).join(Team, Pool.team_id == Team.id, isouter=True) + rows = session.execute(stmt).all() + pools_by_team: dict[str | None, set[str]] = defaultdict(set) + for pool_name, team_name in rows: + pools_by_team[team_name].add(pool_name) + + pool_names: set[str] = set() + for team_name, team_pool_names in pools_by_team.items(): + pool_names.update( + self.filter_authorized_pools( + pool_names=team_pool_names, user=user, method=method, team_name=team_name + ) + ) + + return pool_names + + def filter_authorized_pools( + self, + *, + pool_names: set[str], + user: T, + method: ResourceMethod = "GET", + team_name: str | None = None, + ) -> set[str]: + """ + Filter pools the user has access to. + + By default, check individually if the user has permissions to access the pool. + Can lead to some poor performance. It is recommended to override this method in the auth manager + implementation to provide a more efficient implementation. + + :param pool_names: the set of pool names + :param user: the user + :param method: the method to filter on + :param team_name: the name of the team associated to the connections if Airflow environment runs in + multi-team mode + """ + + def _is_authorized_pool(name: str): + return self.is_authorized_pool( + method=method, details=PoolDetails(name=name, team_name=team_name), user=user + ) - return {dag_id for dag_id in dag_ids if _is_authorized_dag_id(method, dag_id)} + return {pool_name for pool_name in pool_names if _is_authorized_pool(pool_name)} @staticmethod def get_cli_commands() -> list[CLICommand]: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py index 295c6f156f038..28db5ae4fef4b 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py @@ -43,7 +43,11 @@ ConnectionTestResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc -from airflow.api_fastapi.core_api.security import requires_access_connection, requires_access_connection_bulk +from airflow.api_fastapi.core_api.security import ( + ReadableConnectionsFilterDep, + requires_access_connection, + requires_access_connection_bulk, +) from airflow.api_fastapi.core_api.services.public.connections import ( BulkConnectionService, update_orm_from_pydantic, @@ -117,13 +121,14 @@ def get_connections( ).dynamic_depends() ), ], + readable_connections_filter: ReadableConnectionsFilterDep, session: SessionDep, connection_id_pattern: QueryConnectionIdPatternSearch, ) -> ConnectionCollectionResponse: """Get all connection entries.""" connection_select, total_entries = paginated_select( statement=select(Connection), - filters=[connection_id_pattern], + filters=[connection_id_pattern, readable_connections_filter], order_by=order_by, offset=offset, limit=limit, diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/pools.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/pools.py index 835c2a62c3fca..6a7e2f646d979 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/pools.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/pools.py @@ -40,7 +40,11 @@ PoolResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc -from airflow.api_fastapi.core_api.security import requires_access_pool, requires_access_pool_bulk +from airflow.api_fastapi.core_api.security import ( + ReadablePoolsFilterDep, + requires_access_pool, + requires_access_pool_bulk, +) from airflow.api_fastapi.core_api.services.public.pools import BulkPoolService from airflow.api_fastapi.logging.decorators import action_logging from airflow.models.pool import Pool @@ -103,12 +107,13 @@ def get_pools( Depends(SortParam(["id", "pool"], Pool, to_replace={"name": "pool"}).dynamic_depends()), ], pool_name_pattern: QueryPoolNamePatternSearch, + readable_pools_filter: ReadablePoolsFilterDep, session: SessionDep, ) -> PoolCollectionResponse: """Get all pools entries.""" pools_select, total_entries = paginated_select( statement=select(Pool), - filters=[pool_name_pattern], + filters=[pool_name_pattern, readable_pools_filter], order_by=order_by, offset=offset, limit=limit, diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/variables.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/variables.py index eb111c0c6af7f..36fc5be44b1a8 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/variables.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/variables.py @@ -38,7 +38,11 @@ VariableResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc -from airflow.api_fastapi.core_api.security import requires_access_variable, requires_access_variable_bulk +from airflow.api_fastapi.core_api.security import ( + ReadableVariablesFilterDep, + requires_access_variable, + requires_access_variable_bulk, +) from airflow.api_fastapi.core_api.services.public.variables import BulkVariableService from airflow.api_fastapi.logging.decorators import action_logging from airflow.models.variable import Variable @@ -99,13 +103,14 @@ def get_variables( ).dynamic_depends() ), ], + readable_variables_filter: ReadableVariablesFilterDep, session: SessionDep, - varaible_key_pattern: QueryVariableKeyPatternSearch, + variable_key_pattern: QueryVariableKeyPatternSearch, ) -> VariableCollectionResponse: """Get all Variables entries.""" variable_select, total_entries = paginated_select( statement=select(Variable), - filters=[varaible_key_pattern], + filters=[variable_key_pattern, readable_variables_filter], order_by=order_by, offset=offset, limit=limit, diff --git a/airflow-core/src/airflow/api_fastapi/core_api/security.py b/airflow-core/src/airflow/api_fastapi/core_api/security.py index 84bd0ccdd2989..975ad23e116c3 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/security.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/security.py @@ -233,6 +233,36 @@ def inner( return inner +class PermittedPoolFilter(OrmClause[set[str]]): + """A parameter that filters the permitted pools for the user.""" + + def to_orm(self, select: Select) -> Select: + return select.where(Pool.pool.in_(self.value)) + + +def permitted_pool_filter_factory( + method: ResourceMethod, +) -> Callable[[Request, BaseUser], PermittedPoolFilter]: + """ + Create a callable for Depends in FastAPI that returns a filter of the permitted pools for the user. + + :param method: whether filter readable or writable. + """ + + def depends_permitted_pools_filter( + request: Request, + user: GetUserDep, + ) -> PermittedPoolFilter: + auth_manager: BaseAuthManager = request.app.state.auth_manager + authorized_pools: set[str] = auth_manager.get_authorized_pools(user=user, method=method) + return PermittedPoolFilter(authorized_pools) + + return depends_permitted_pools_filter + + +ReadablePoolsFilterDep = Annotated[PermittedPoolFilter, Depends(permitted_pool_filter_factory("GET"))] + + def requires_access_pool(method: ResourceMethod) -> Callable[[Request, BaseUser], None]: def inner( request: Request, @@ -281,6 +311,38 @@ def inner( return inner +class PermittedConnectionFilter(OrmClause[set[str]]): + """A parameter that filters the permitted connections for the user.""" + + def to_orm(self, select: Select) -> Select: + return select.where(Connection.conn_id.in_(self.value)) + + +def permitted_connection_filter_factory( + method: ResourceMethod, +) -> Callable[[Request, BaseUser], PermittedConnectionFilter]: + """ + Create a callable for Depends in FastAPI that returns a filter of the permitted connections for the user. + + :param method: whether filter readable or writable. + """ + + def depends_permitted_connections_filter( + request: Request, + user: GetUserDep, + ) -> PermittedConnectionFilter: + auth_manager: BaseAuthManager = request.app.state.auth_manager + authorized_connections: set[str] = auth_manager.get_authorized_connections(user=user, method=method) + return PermittedConnectionFilter(authorized_connections) + + return depends_permitted_connections_filter + + +ReadableConnectionsFilterDep = Annotated[ + PermittedConnectionFilter, Depends(permitted_connection_filter_factory("GET")) +] + + def requires_access_connection(method: ResourceMethod) -> Callable[[Request, BaseUser], None]: def inner( request: Request, @@ -349,6 +411,38 @@ def inner( return inner +class PermittedVariableFilter(OrmClause[set[str]]): + """A parameter that filters the permitted variables for the user.""" + + def to_orm(self, select: Select) -> Select: + return select.where(Variable.key.in_(self.value)) + + +def permitted_variable_filter_factory( + method: ResourceMethod, +) -> Callable[[Request, BaseUser], PermittedVariableFilter]: + """ + Create a callable for Depends in FastAPI that returns a filter of the permitted variables for the user. + + :param method: whether filter readable or writable. + """ + + def depends_permitted_variables_filter( + request: Request, + user: GetUserDep, + ) -> PermittedVariableFilter: + auth_manager: BaseAuthManager = request.app.state.auth_manager + authorized_variables: set[str] = auth_manager.get_authorized_variables(user=user, method=method) + return PermittedVariableFilter(authorized_variables) + + return depends_permitted_variables_filter + + +ReadableVariablesFilterDep = Annotated[ + PermittedVariableFilter, Depends(permitted_variable_filter_factory("GET")) +] + + def requires_access_variable(method: ResourceMethod) -> Callable[[Request, BaseUser], None]: def inner( request: Request, diff --git a/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py b/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py index fde846656096a..010cf24481f7a 100644 --- a/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py +++ b/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py @@ -326,42 +326,226 @@ def test_batch_is_authorized_variable( assert result == expected @pytest.mark.parametrize( - "access_per_dag, dag_ids, expected", + "access_per_dag, access_per_team, rows, expected", [ + # Without teams # No access to any dag ( {}, - ["dag1", "dag2"], + {}, + [("dag1", None), ("dag2", None)], set(), ), # Access to specific dags ( {"dag1": True}, - ["dag1", "dag2"], + {}, + [("dag1", None), ("dag2", None)], {"dag1"}, ), + # With teams + # No access to any dag + ( + {}, + {}, + [("dag1", "team1"), ("dag2", "team2")], + set(), + ), + # Access to a specific team + ( + {}, + {"team1": True}, + [("dag1", "team1"), ("dag2", "team1"), ("dag3", "team2")], + {"dag1", "dag2"}, + ), ], ) - def test_get_authorized_dag_ids(self, auth_manager, access_per_dag: dict, dag_ids: list, expected: set): + def test_get_authorized_dag_ids( + self, auth_manager, access_per_dag: dict, access_per_team: dict, rows: list, expected: set + ): def side_effect_func( *, method: ResourceMethod, + user: BaseAuthManagerUserTest, access_entity: DagAccessEntity | None = None, details: DagDetails | None = None, - user: BaseAuthManagerUserTest | None = None, ): if not details: return False - return access_per_dag.get(details.id, False) + return access_per_dag.get(details.id, False) or access_per_team.get(details.team_name, False) auth_manager.is_authorized_dag = MagicMock(side_effect=side_effect_func) user = Mock() session = Mock() - dags = [] - for dag_id in dag_ids: - mock = Mock() - mock.dag_id = dag_id - dags.append(mock) - session.execute.return_value = dags + session.execute.return_value.all.return_value = rows result = auth_manager.get_authorized_dag_ids(user=user, session=session) assert result == expected + + @pytest.mark.parametrize( + "access_per_connection, access_per_team, rows, expected", + [ + # Without teams + # No access to any connection + ( + {}, + {}, + [("conn1", None), ("conn2", None)], + set(), + ), + # Access to specific connections + ( + {"conn1": True}, + {}, + [("conn1", None), ("conn2", None)], + {"conn1"}, + ), + # With teams + # No access to any connection + ( + {}, + {}, + [("conn1", "team1"), ("conn2", "team2")], + set(), + ), + # Access to a specific team + ( + {}, + {"team1": True}, + [("conn1", "team1"), ("conn2", "team1"), ("conn3", "team2")], + {"conn1", "conn2"}, + ), + ], + ) + def test_get_authorized_connections( + self, auth_manager, access_per_connection: dict, access_per_team: dict, rows: list, expected: set + ): + def side_effect_func( + *, + method: ResourceMethod, + user: BaseAuthManagerUserTest, + details: ConnectionDetails | None = None, + ): + if not details: + return False + return access_per_connection.get(details.conn_id, False) or access_per_team.get( + details.team_name, False + ) + + auth_manager.is_authorized_connection = MagicMock(side_effect=side_effect_func) + user = Mock() + session = Mock() + session.execute.return_value.all.return_value = rows + result = auth_manager.get_authorized_connections(user=user, session=session) + assert result == expected + + @pytest.mark.parametrize( + "access_per_variable, access_per_team, rows, expected", + [ + # Without teams + # No access to any variable + ( + {}, + {}, + [("var1", None), ("var2", None)], + set(), + ), + # Access to specific variables + ( + {"var1": True}, + {}, + [("var1", None), ("var2", None)], + {"var1"}, + ), + # With teams + # No access to any variable + ( + {}, + {}, + [("var1", "team1"), ("var2", "team2")], + set(), + ), + # Access to a specific team + ( + {}, + {"team1": True}, + [("var1", "team1"), ("var2", "team1"), ("var3", "team2")], + {"var1", "var2"}, + ), + ], + ) + def test_get_authorized_variables( + self, auth_manager, access_per_variable: dict, access_per_team: dict, rows: list, expected: set + ): + def side_effect_func( + *, + method: ResourceMethod, + user: BaseAuthManagerUserTest, + details: VariableDetails | None = None, + ): + if not details: + return False + return access_per_variable.get(details.key, False) or access_per_team.get( + details.team_name, False + ) + + auth_manager.is_authorized_variable = MagicMock(side_effect=side_effect_func) + user = Mock() + session = Mock() + session.execute.return_value.all.return_value = rows + result = auth_manager.get_authorized_variables(user=user, session=session) + assert result == expected + + @pytest.mark.parametrize( + "access_per_pool, access_per_team, rows, expected", + [ + # Without teams + # No access to any pool + ( + {}, + {}, + [("pool1", None), ("pool2", None)], + set(), + ), + # Access to specific pools + ( + {"pool1": True}, + {}, + [("pool1", None), ("pool2", None)], + {"pool1"}, + ), + # With teams + # No access to any pool + ( + {}, + {}, + [("pool1", "team1"), ("pool2", "team2")], + set(), + ), + # Access to a specific team + ( + {}, + {"team1": True}, + [("pool1", "team1"), ("pool2", "team1"), ("pool3", "team2")], + {"pool1", "pool2"}, + ), + ], + ) + def test_get_authorized_pools( + self, auth_manager, access_per_pool: dict, access_per_team: dict, rows: list, expected: set + ): + def side_effect_func( + *, + method: ResourceMethod, + user: BaseAuthManagerUserTest, + details: PoolDetails | None = None, + ): + if not details: + return False + return access_per_pool.get(details.name, False) or access_per_team.get(details.team_name, False) + + auth_manager.is_authorized_pool = MagicMock(side_effect=side_effect_func) + user = Mock() + session = Mock() + session.execute.return_value.all.return_value = rows + result = auth_manager.get_authorized_pools(user=user, session=session) + assert result == expected diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py index 1dc77cede75b7..b8f60a8676286 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py @@ -232,6 +232,20 @@ def test_should_respond_403(self, unauthorized_test_client): response = unauthorized_test_client.get("/connections", params={}) assert response.status_code == 403 + @mock.patch( + "airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_authorized_connections" + ) + def test_should_call_get_authorized_connections(self, mock_get_authorized_connections, test_client): + self.create_connections() + mock_get_authorized_connections.return_value = {TEST_CONN_ID} + response = test_client.get("/connections") + mock_get_authorized_connections.assert_called_once_with(user=mock.ANY, method="GET") + assert response.status_code == 200 + body = response.json() + + assert body["total_entries"] == 1 + assert [connection["connection_id"] for connection in body["connections"]] == [TEST_CONN_ID] + class TestPostConnection(TestConnectionEndpoint): @pytest.mark.parametrize( diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py index 7612edebadc5e..62699f58b9c37 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py @@ -449,7 +449,7 @@ def test_get_dags(self, test_client, query_params, expected_total_entries, expec assert actual_ids == expected_ids @mock.patch("airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_authorized_dag_ids") - def test_get_dags_should_call_authorized_dag_ids(self, mock_get_authorized_dag_ids, test_client): + def test_get_dags_should_call_get_authorized_dag_ids(self, mock_get_authorized_dag_ids, test_client): mock_get_authorized_dag_ids.return_value = {DAG1_ID, DAG2_ID} response = test_client.get("/dags") mock_get_authorized_dag_ids.assert_called_once_with(user=mock.ANY, method="GET") diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py index 8112ed06d7634..d4d16abfb8f91 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py @@ -16,6 +16,8 @@ # under the License. from __future__ import annotations +from unittest import mock + import pytest from airflow.models.pool import Pool @@ -202,6 +204,18 @@ def test_should_respond_403(self, unauthorized_test_client): response = unauthorized_test_client.get("/pools", params={"pool_name_pattern": "~"}) assert response.status_code == 403 + @mock.patch("airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_authorized_pools") + def test_should_call_get_authorized_pools(self, mock_get_authorized_pools, test_client): + self.create_pools() + mock_get_authorized_pools.return_value = {Pool.DEFAULT_POOL_NAME, POOL1_NAME} + response = test_client.get("/pools") + mock_get_authorized_pools.assert_called_once_with(user=mock.ANY, method="GET") + assert response.status_code == 200 + body = response.json() + + assert body["total_entries"] == 2 + assert [pool["name"] for pool in body["pools"]] == [Pool.DEFAULT_POOL_NAME, POOL1_NAME] + class TestPatchPool(TestPoolsEndpoint): @pytest.mark.parametrize( diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py index 3b26fa8779411..076905ea03b85 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py @@ -308,6 +308,20 @@ def test_get_should_respond_403(self, unauthorized_test_client): response = unauthorized_test_client.get("/variables") assert response.status_code == 403 + @mock.patch( + "airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_authorized_variables" + ) + def test_should_call_get_authorized_variables(self, mock_get_authorized_variables, test_client): + self.create_variables() + mock_get_authorized_variables.return_value = {TEST_VARIABLE_KEY, TEST_VARIABLE_KEY2} + response = test_client.get("/variables") + mock_get_authorized_variables.assert_called_once_with(user=mock.ANY, method="GET") + assert response.status_code == 200 + body = response.json() + + assert body["total_entries"] == 2 + assert [variable["key"] for variable in body["variables"]] == [TEST_VARIABLE_KEY, TEST_VARIABLE_KEY2] + class TestPatchVariable(TestVariableEndpoint): @pytest.mark.enable_redact diff --git a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py index 387d968ec157f..9f2e73b297b56 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py @@ -273,6 +273,7 @@ def filter_authorized_dag_ids( dag_ids: set[str], user: AwsAuthManagerUser, method: ResourceMethod = "GET", + team_name: str | None = None, ): requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict) requests_list: list[IsAuthorizedRequest] = []