Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions airflow-core/docs/core-concepts/auth-manager/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,14 @@ Optional methods recommended to override for optimization

The following methods aren't required to override to have a functional Airflow auth manager. However, it is recommended to override these to make your auth manager faster (and potentially less costly):

* ``batch_is_authorized_connection``: Batch version of ``is_authorized_connection``. If not overridden, it will call ``is_authorized_connection`` for every single item.
* ``batch_is_authorized_dag``: Batch version of ``is_authorized_dag``. If not overridden, it will call ``is_authorized_dag`` for every single item.
* ``batch_is_authorized_pool``: Batch version of ``is_authorized_pool``. If not overridden, it will call ``is_authorized_pool`` for every single item.
* ``batch_is_authorized_variable``: Batch version of ``is_authorized_variable``. If not overridden, it will call ``is_authorized_variable`` for every single item.
* ``get_authorized_dag_ids``: Return the list of Dag IDs the user has access to. If not overridden, it will call ``is_authorized_dag`` for every single Dag available in the environment.

* Note: To filter the results of ``get_authorized_dag_ids``, it is recommended that you define the filtering logic in your ``filter_authorized_dag_ids`` method. For example, this may be useful if you rely on per-Dag access controls derived from one or more fields on a given Dag (e.g. Dag tags).
* This method requires an active session with the Airflow metadata database. As such, overriding the ``get_authorized_dag_ids`` method is an advanced use case, which should be considered carefully -- it is recommended you refer to the :doc:`../../database-erd-ref`.
* ``batch_is_authorized_connection``: Batch version of ``is_authorized_connection``. If not overridden, it calls ``is_authorized_connection`` for every single item.
* ``batch_is_authorized_dag``: Batch version of ``is_authorized_dag``. If not overridden, it calls ``is_authorized_dag`` for every single item.
* ``batch_is_authorized_pool``: Batch version of ``is_authorized_pool``. If not overridden, it calls ``is_authorized_pool`` for every single item.
* ``batch_is_authorized_variable``: Batch version of ``is_authorized_variable``. If not overridden, it calls ``is_authorized_variable`` for every single item.
* ``filter_authorized_connections``: Given a list of connection IDs (``conn_id``), return the list of connection IDs the user has access to. If not overridden, it calls ``is_authorized_connection`` for every single connection passed as parameter.
* ``filter_authorized_dag_ids``: Given a list of Dag IDs, return the list of Dag IDs the user has access to. If not overridden, it calls ``is_authorized_dag`` for every single Dag passed as parameter.
* ``filter_authorized_pools``: Given a list of pool names, return the list of pool names the user has access to. If not overridden, it calls ``is_authorized_pool`` for every single pool passed as parameter.
* ``filter_authorized_variables``: Given a list of variable keys, return the list of variable keys the user has access to. If not overridden, it calls ``is_authorized_variable`` for every single variable passed as parameter.

CLI
^^^
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,21 @@

import logging
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from functools import cache
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar

from jwt import InvalidTokenError
from sqlalchemy import select

from airflow.api_fastapi.auth.managers.models.base_user import BaseUser
from airflow.api_fastapi.auth.managers.models.resource_details import BackfillDetails, DagDetails
from airflow.api_fastapi.auth.managers.models.resource_details import (
BackfillDetails,
ConnectionDetails,
DagDetails,
PoolDetails,
VariableDetails,
)
from airflow.api_fastapi.auth.tokens import (
JWTGenerator,
JWTValidator,
Expand All @@ -35,7 +42,9 @@
)
from airflow.api_fastapi.common.types import ExtraMenuItem, MenuItem
from airflow.configuration import conf
from airflow.models import DagModel
from airflow.models import Connection, DagModel, Pool, Variable
from airflow.models.dagbundle import DagBundleModel
from airflow.models.team import Team, dag_bundle_team_association_table
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session

Expand All @@ -56,10 +65,7 @@
AssetAliasDetails,
AssetDetails,
ConfigurationDetails,
ConnectionDetails,
DagAccessEntity,
PoolDetails,
VariableDetails,
)
from airflow.cli.cli_config import CLICommand

Expand Down Expand Up @@ -427,36 +433,243 @@ def get_authorized_dag_ids(
"""
Get DAGs the user has access to.

By default, reads all the DAGs and check individually if the user has permissions to access the DAG.
Can lead to some poor performance. It is recommended to override this method in the auth manager
implementation to provide a more efficient implementation.

:param user: the user
:param method: the method to filter on
:param session: the session
"""
dag_ids = {dag.dag_id for dag in session.execute(select(DagModel.dag_id))}
return self.filter_authorized_dag_ids(dag_ids=dag_ids, method=method, user=user)
stmt = (
select(DagModel.dag_id, Team.name)
.join(DagBundleModel, DagModel.bundle_name == DagBundleModel.name)
.join(
dag_bundle_team_association_table,
DagBundleModel.name == dag_bundle_team_association_table.c.dag_bundle_name,
isouter=True,
)
.join(Team, Team.id == dag_bundle_team_association_table.c.team_id, isouter=True)
)
rows = session.execute(stmt).all()
dags_by_team: dict[str | None, set[str]] = defaultdict(set)
for dag_id, team_name in rows:
dags_by_team[team_name].add(dag_id)

dag_ids: set[str] = set()
for team_name, team_dag_ids in dags_by_team.items():
dag_ids.update(
self.filter_authorized_dag_ids(
dag_ids=team_dag_ids, user=user, method=method, team_name=team_name
)
)

return dag_ids

def filter_authorized_dag_ids(
    self,
    *,
    dag_ids: set[str],
    user: T,
    method: ResourceMethod = "GET",
    team_name: str | None = None,
) -> set[str]:
    """
    Filter DAGs the user has access to.

    By default, check individually if the user has permissions to access the DAG.
    Can lead to some poor performance. It is recommended to override this method in the auth manager
    implementation to provide a more efficient implementation.

    :param dag_ids: the set of DAG ids
    :param user: the user
    :param method: the method to filter on
    :param team_name: the name of the team associated to the Dags if Airflow environment runs in
        multi-team mode
    :return: the DAG ids from ``dag_ids`` for which ``is_authorized_dag`` returns a truthy value
    """

    def _is_authorized_dag_id(dag_id: str):
        # Every DAG id in this call is checked against the same team, method and user.
        return self.is_authorized_dag(
            method=method, details=DagDetails(id=dag_id, team_name=team_name), user=user
        )

    return {dag_id for dag_id in dag_ids if _is_authorized_dag_id(dag_id)}

@provide_session
def get_authorized_connections(
    self,
    *,
    user: T,
    method: ResourceMethod = "GET",
    session: Session = NEW_SESSION,
) -> set[str]:
    """
    Return the connection ids (``conn_id``) the user has access to.

    Every connection id is read together with the name of the team it is joined to, the ids are
    grouped per team name, and each group is handed to ``filter_authorized_connections``.

    :param user: the user
    :param method: the method to filter on
    :param session: the session
    """
    query = select(Connection.conn_id, Team.name).join(
        Team, Connection.team_id == Team.id, isouter=True
    )

    # Group the connection ids by team name. The team name is whatever the outer join returned
    # for the row (presumably ``None`` when the connection has no team).
    grouped: dict[str | None, set[str]] = {}
    for connection_id, owning_team in session.execute(query).all():
        grouped.setdefault(owning_team, set()).add(connection_id)

    # One filter call per team so that the team name can be forwarded to the authorization check.
    authorized: set[str] = set()
    for owning_team, candidate_ids in grouped.items():
        allowed = self.filter_authorized_connections(
            conn_ids=candidate_ids, user=user, method=method, team_name=owning_team
        )
        authorized.update(allowed)
    return authorized

def filter_authorized_connections(
    self,
    *,
    conn_ids: set[str],
    user: T,
    method: ResourceMethod = "GET",
    team_name: str | None = None,
) -> set[str]:
    """
    Keep only the connections the user has access to.

    The default implementation calls ``is_authorized_connection`` once per connection id, which can
    be slow. Auth manager implementations are encouraged to override this method with a more
    efficient implementation.

    :param conn_ids: the set of connection ids (``conn_id``)
    :param user: the user
    :param method: the method to filter on
    :param team_name: the name of the team associated to the connections if Airflow environment runs in
        multi-team mode
    """
    permitted: set[str] = set()
    for candidate in conn_ids:
        details = ConnectionDetails(conn_id=candidate, team_name=team_name)
        if self.is_authorized_connection(method=method, details=details, user=user):
            permitted.add(candidate)
    return permitted

@provide_session
def get_authorized_variables(
    self,
    *,
    user: T,
    method: ResourceMethod = "GET",
    session: Session = NEW_SESSION,
) -> set[str]:
    """
    Get variable keys the user has access to.

    Variable keys are read together with the name of the team they are joined to, grouped per team
    name, and each group is passed to ``filter_authorized_variables``.

    :param user: the user
    :param method: the method to filter on
    :param session: the session
    """
    stmt = select(Variable.key, Team.name).join(Team, Variable.team_id == Team.id, isouter=True)
    rows = session.execute(stmt).all()

    # Group keys by team name so that a single filter call can be made per team.
    variables_by_team: dict[str | None, set[str]] = defaultdict(set)
    for var_key, team_name in rows:
        variables_by_team[team_name].add(var_key)

    var_keys: set[str] = set()
    for team_name, team_var_keys in variables_by_team.items():
        var_keys.update(
            self.filter_authorized_variables(
                variable_keys=team_var_keys, user=user, method=method, team_name=team_name
            )
        )

    return var_keys

def filter_authorized_variables(
    self,
    *,
    variable_keys: set[str],
    user: T,
    method: ResourceMethod = "GET",
    team_name: str | None = None,
) -> set[str]:
    """
    Filter variables the user has access to.

    By default, check individually if the user has permissions to access the variable.
    Can lead to some poor performance. It is recommended to override this method in the auth manager
    implementation to provide a more efficient implementation.

    :param variable_keys: the set of variable keys
    :param user: the user
    :param method: the method to filter on
    :param team_name: the name of the team associated to the variables if Airflow environment runs in
        multi-team mode
    :return: the keys from ``variable_keys`` for which ``is_authorized_variable`` returns a truthy value
    """

    def _is_authorized_variable(var_key: str):
        # Every key in this call is checked against the same team, method and user.
        return self.is_authorized_variable(
            method=method, details=VariableDetails(key=var_key, team_name=team_name), user=user
        )

    return {var_key for var_key in variable_keys if _is_authorized_variable(var_key)}

@provide_session
def get_authorized_pools(
    self,
    *,
    user: T,
    method: ResourceMethod = "GET",
    session: Session = NEW_SESSION,
) -> set[str]:
    """
    Return the names of the pools the user has access to.

    Pool names are read together with the name of the team they are joined to, grouped per team
    name, and each group is handed to ``filter_authorized_pools``.

    :param user: the user
    :param method: the method to filter on
    :param session: the session
    """
    query = select(Pool.pool, Team.name).join(Team, Pool.team_id == Team.id, isouter=True)

    # Group the pool names by the team name returned for each row.
    grouped: dict[str | None, set[str]] = {}
    for pool, owning_team in session.execute(query).all():
        grouped.setdefault(owning_team, set()).add(pool)

    # One filter call per team so that the team name can be forwarded to the authorization check.
    authorized: set[str] = set()
    for owning_team, candidate_names in grouped.items():
        allowed = self.filter_authorized_pools(
            pool_names=candidate_names, user=user, method=method, team_name=owning_team
        )
        authorized.update(allowed)
    return authorized

def filter_authorized_pools(
    self,
    *,
    pool_names: set[str],
    user: T,
    method: ResourceMethod = "GET",
    team_name: str | None = None,
) -> set[str]:
    """
    Filter pools the user has access to.

    By default, check individually if the user has permissions to access the pool.
    Can lead to some poor performance. It is recommended to override this method in the auth manager
    implementation to provide a more efficient implementation.

    :param pool_names: the set of pool names
    :param user: the user
    :param method: the method to filter on
    :param team_name: the name of the team associated to the pools if Airflow environment runs in
        multi-team mode
    :return: the names from ``pool_names`` for which ``is_authorized_pool`` returns a truthy value
    """

    def _is_authorized_pool(name: str):
        # Every pool name in this call is checked against the same team, method and user.
        return self.is_authorized_pool(
            method=method, details=PoolDetails(name=name, team_name=team_name), user=user
        )

    # Only the pool-based filter is returned; the leftover DAG-based return referenced names
    # (``dag_ids``, ``_is_authorized_dag_id``) that do not exist in this method.
    return {pool_name for pool_name in pool_names if _is_authorized_pool(pool_name)}

@staticmethod
def get_cli_commands() -> list[CLICommand]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@
ConnectionTestResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import requires_access_connection, requires_access_connection_bulk
from airflow.api_fastapi.core_api.security import (
ReadableConnectionsFilterDep,
requires_access_connection,
requires_access_connection_bulk,
)
from airflow.api_fastapi.core_api.services.public.connections import (
BulkConnectionService,
update_orm_from_pydantic,
Expand Down Expand Up @@ -117,13 +121,14 @@ def get_connections(
).dynamic_depends()
),
],
readable_connections_filter: ReadableConnectionsFilterDep,
session: SessionDep,
connection_id_pattern: QueryConnectionIdPatternSearch,
) -> ConnectionCollectionResponse:
"""Get all connection entries."""
connection_select, total_entries = paginated_select(
statement=select(Connection),
filters=[connection_id_pattern],
filters=[connection_id_pattern, readable_connections_filter],
order_by=order_by,
offset=offset,
limit=limit,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@
PoolResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import requires_access_pool, requires_access_pool_bulk
from airflow.api_fastapi.core_api.security import (
ReadablePoolsFilterDep,
requires_access_pool,
requires_access_pool_bulk,
)
from airflow.api_fastapi.core_api.services.public.pools import BulkPoolService
from airflow.api_fastapi.logging.decorators import action_logging
from airflow.models.pool import Pool
Expand Down Expand Up @@ -103,12 +107,13 @@ def get_pools(
Depends(SortParam(["id", "pool"], Pool, to_replace={"name": "pool"}).dynamic_depends()),
],
pool_name_pattern: QueryPoolNamePatternSearch,
readable_pools_filter: ReadablePoolsFilterDep,
session: SessionDep,
) -> PoolCollectionResponse:
"""Get all pools entries."""
pools_select, total_entries = paginated_select(
statement=select(Pool),
filters=[pool_name_pattern],
filters=[pool_name_pattern, readable_pools_filter],
order_by=order_by,
offset=offset,
limit=limit,
Expand Down
Loading