Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def get_authorized_dag_ids(
*,
user: T,
method: ResourceMethod = "GET",
access_entity: DagAccessEntity | None = None,
session: Session = NEW_SESSION,
) -> set[str]:
"""
Expand All @@ -355,28 +356,36 @@ def get_authorized_dag_ids(

:param user: the user
:param method: the method to filter on
:param access_entity: the kind of DAG information the user wants to access.
:param session: the session
"""
dag_ids = {dag.dag_id for dag in session.execute(select(DagModel.dag_id))}
return self.filter_authorized_dag_ids(dag_ids=dag_ids, method=method, user=user)
return self.filter_authorized_dag_ids(
dag_ids=dag_ids, method=method, access_entity=access_entity, user=user
)

def filter_authorized_dag_ids(
self,
*,
dag_ids: set[str],
user: T,
method: ResourceMethod = "GET",
access_entity: DagAccessEntity | None = None,
) -> set[str]:
"""
Filter DAGs the user has access to.

:param dag_ids: the list of DAG ids
:param user: the user
:param method: the method to filter on
:param access_entity: the kind of DAG information the authorization request is about.
If not provided, the authorization request is about the DAG itself
"""

def _is_authorized_dag_id(method: ResourceMethod, dag_id: str):
return self.is_authorized_dag(method=method, details=DagDetails(id=dag_id), user=user)
return self.is_authorized_dag(
method=method, details=DagDetails(id=dag_id), access_entity=access_entity, user=user
)

return {dag_id for dag_id in dag_ids if _is_authorized_dag_id(method, dag_id)}

Expand Down
10 changes: 0 additions & 10 deletions airflow-core/src/airflow/api_fastapi/common/db/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,6 @@ async def paginated_select_async(
if return_total_entries:
total_entries = await get_query_count_async(statement, session=session)

# TODO: Re-enable when permissions are handled. Readable / writable entities,
# for instance:
# readable_dags = get_auth_manager().get_authorized_dag_ids(user=g.user)
# dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags))

statement = apply_filters_to_select(
statement=statement,
filters=[order_by, offset, limit],
Expand Down Expand Up @@ -171,11 +166,6 @@ def paginated_select(
if return_total_entries:
total_entries = get_query_count(statement, session=session)

# TODO: Re-enable when permissions are handled. Readable / writable entities,
# for instance:
# readable_dags = get_auth_manager().get_authorized_dag_ids(user=g.user)
# dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags))

statement = apply_filters_to_select(statement=statement, filters=[order_by, offset, limit])

return statement, total_entries
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,6 @@ def get_dag_asset_queued_event(
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
dependencies=[
Depends(requires_access_asset(method="DELETE")),
Depends(requires_access_dag(method="GET")),
Depends(action_logging()),
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import os
from typing import cast

from fastapi import Depends, HTTPException, status
from fastapi import HTTPException, status

from airflow import settings
from airflow.api_fastapi.common.router import AirflowRouter
Expand All @@ -32,7 +32,6 @@
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import (
ReadableDagsFilterDep,
requires_access_dag,
)
from airflow.models.dagbag import DagBag

Expand All @@ -46,7 +45,7 @@
status.HTTP_400_BAD_REQUEST,
]
),
dependencies=[Depends(requires_access_dag(method="GET"))],
# No authorization access is performed on the API level because `ReadableDagsFilterDep` filters Dags accessible by the user only
)
def get_dag_reports(
subdir: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def clear_dag_run(
@dag_run_router.get(
"",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.RUN))],
# No authorization access is performed on the API level because `ReadableDagRunsFilterDep` filters Dags accessible by the user only
)
def get_dag_runs(
dag_id: str,
Expand Down Expand Up @@ -502,7 +502,7 @@ def wait_dag_run_until_finished(
@dag_run_router.post(
"/list",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.RUN))],
# No authorization access is performed on the API level because `ReadableDagRunsFilterDep` filters Dags accessible by the user only
)
def get_list_dag_runs_batch(
dag_id: Literal["~"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

from fastapi import Depends, status

from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity
from airflow.api_fastapi.common.db.common import (
SessionDep,
paginated_select,
Expand All @@ -39,7 +38,7 @@
DagStatsStateResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import ReadableDagRunsFilterDep, requires_access_dag
from airflow.api_fastapi.core_api.security import ReadableDagRunsFilterDep
from airflow.models.dagrun import DagRun
from airflow.utils.state import DagRunState

Expand All @@ -54,7 +53,7 @@
status.HTTP_404_NOT_FOUND,
]
),
dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.RUN))],
# No authorization access is performed on the API level because `ReadableDagRunsFilterDep` filters Dags accessible by the user only
)
def get_dag_stats(
readable_dag_runs_filter: ReadableDagRunsFilterDep,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,16 @@
)
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.dag_tags import DAGTagCollectionResponse
from airflow.api_fastapi.core_api.security import ReadableTagsFilterDep, requires_access_dag
from airflow.api_fastapi.core_api.security import ReadableTagsFilterDep
from airflow.models.dag import DagTag

dag_tags_router = AirflowRouter(tags=["DAG"], prefix="/dagTags")


@dag_tags_router.get("", dependencies=[Depends(requires_access_dag(method="GET"))])
@dag_tags_router.get(
"",
# No authorization access is performed on the API level because `ReadableTagsFilterDep` filters Dags accessible by the user only
)
def get_dag_tags(
limit: QueryLimit,
offset: QueryOffset,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from fastapi import Depends
from sqlalchemy import select

from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity
from airflow.api_fastapi.common.db.common import (
SessionDep,
paginated_select,
Expand All @@ -38,15 +37,15 @@
from airflow.api_fastapi.core_api.datamodels.dag_warning import (
DAGWarningCollectionResponse,
)
from airflow.api_fastapi.core_api.security import ReadableDagWarningsFilterDep, requires_access_dag
from airflow.api_fastapi.core_api.security import ReadableDagWarningsFilterDep
from airflow.models.dagwarning import DagWarning, DagWarningType

dag_warning_router = AirflowRouter(tags=["DagWarning"])


@dag_warning_router.get(
"/dagWarnings",
dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.WARNING))],
# No authorization access is performed on the API level because `ReadableDagWarningsFilterDep` filters Dags accessible by the user only
)
def list_dag_warnings(
dag_id: Annotated[FilterParam[str | None], Depends(filter_param_factory(DagWarning.dag_id, str | None))],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,8 @@ def patch_dag(
status.HTTP_404_NOT_FOUND,
]
),
dependencies=[Depends(requires_access_dag(method="PUT")), Depends(action_logging())],
# No authorization access is performed on the API level because `EditableDagsFilterDep` filters Dags accessible by the user only
dependencies=[Depends(action_logging())],
)
def patch_dags(
patch_body: DAGPatchBody,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@
UpdateHITLDetailPayload,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import GetUserDep, ReadableTIFilterDep, requires_access_dag
from airflow.api_fastapi.core_api.security import (
GetUserDep,
ReadableHITLFilterDep,
requires_access_dag,
)
from airflow.models.hitl import HITLDetail as HITLDetailModel
from airflow.models.taskinstance import TaskInstance as TI

Expand Down Expand Up @@ -274,7 +278,7 @@ def get_mapped_ti_hitl_detail(
@hitl_router.get(
"/",
status_code=status.HTTP_200_OK,
dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.HITL_DETAIL))],
# No authorization access is performed on the API level because `ReadableHITLFilterDep` filters Dags accessible by the user only
)
def get_hitl_details(
limit: QueryLimit,
Expand All @@ -300,7 +304,7 @@ def get_hitl_details(
],
session: SessionDep,
# ti related filter
readable_ti_filter: ReadableTIFilterDep,
readable_ti_filter: ReadableHITLFilterDep,
dag_id_pattern: QueryHITLDetailDagIdPatternSearch,
dag_run_id: QueryHITLDetailDagRunIdFilter,
task_id: QueryHITLDetailTaskIdPatternSearch,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def get_task_instances(
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
dependencies=[
Depends(action_logging()),
Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE)),
# No authorization access is performed on the API level because `ReadableTIFilterDep` filters Dags accessible by the user only
],
)
def get_task_instances_batch(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,7 @@
@dags_router.get(
"",
response_model_exclude_none=True,
dependencies=[
Depends(requires_access_dag(method="GET")),
Depends(requires_access_dag("GET", DagAccessEntity.RUN)),
],
# No authorization access is performed on the API level because `ReadableDagsFilterDep` filters Dags accessible by the user only
operation_id="get_dags_ui",
)
def get_dags(
Expand Down
29 changes: 20 additions & 9 deletions airflow-core/src/airflow/api_fastapi/core_api/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,14 @@ def to_orm(self, select: Select) -> Select:


def permitted_dag_filter_factory(
method: ResourceMethod, filter_class=PermittedDagFilter
method: ResourceMethod, access_entity: DagAccessEntity | None = None, filter_class=PermittedDagFilter
) -> Callable[[Request, BaseUser], PermittedDagFilter]:
"""
Create a callable for Depends in FastAPI that returns a filter of the permitted dags for the user.

:param method: whether filter readable or writable.
:param method: the method to filter on
:param access_entity: the Dag sub entity to filter on. If not provided, filter on Dag level.
:param filter_class: the class to filter on. If not provided, filter on Dag level
:return: The callable that can be used as Depends in FastAPI.
"""

Expand All @@ -170,7 +172,9 @@ def depends_permitted_dags_filter(
user: GetUserDep,
) -> PermittedDagFilter:
auth_manager: BaseAuthManager = request.app.state.auth_manager
authorized_dags: set[str] = auth_manager.get_authorized_dag_ids(user=user, method=method)
authorized_dags: set[str] = auth_manager.get_authorized_dag_ids(
user=user, method=method, access_entity=access_entity
)
return filter_class(authorized_dags)

return depends_permitted_dags_filter
Expand All @@ -179,20 +183,27 @@ def depends_permitted_dags_filter(
EditableDagsFilterDep = Annotated[PermittedDagFilter, Depends(permitted_dag_filter_factory("PUT"))]
ReadableDagsFilterDep = Annotated[PermittedDagFilter, Depends(permitted_dag_filter_factory("GET"))]
ReadableDagRunsFilterDep = Annotated[
PermittedDagRunFilter, Depends(permitted_dag_filter_factory("GET", PermittedDagRunFilter))
PermittedDagRunFilter,
Depends(permitted_dag_filter_factory("GET", DagAccessEntity.RUN, PermittedDagRunFilter)),
]
ReadableDagWarningsFilterDep = Annotated[
PermittedDagWarningFilter, Depends(permitted_dag_filter_factory("GET", PermittedDagWarningFilter))
PermittedDagWarningFilter,
Depends(permitted_dag_filter_factory("GET", DagAccessEntity.WARNING, PermittedDagWarningFilter)),
]
ReadableHITLFilterDep = Annotated[
PermittedTIFilter,
Depends(permitted_dag_filter_factory("GET", DagAccessEntity.HITL_DETAIL, PermittedTIFilter)),
]
ReadableTIFilterDep = Annotated[
PermittedTIFilter, Depends(permitted_dag_filter_factory("GET", PermittedTIFilter))
PermittedTIFilter,
Depends(permitted_dag_filter_factory("GET", DagAccessEntity.TASK_INSTANCE, PermittedTIFilter)),
]
ReadableXComFilterDep = Annotated[
PermittedXComFilter, Depends(permitted_dag_filter_factory("GET", PermittedXComFilter))
PermittedXComFilter,
Depends(permitted_dag_filter_factory("GET", DagAccessEntity.XCOM, PermittedXComFilter)),
]

ReadableTagsFilterDep = Annotated[
PermittedTagFilter, Depends(permitted_dag_filter_factory("GET", PermittedTagFilter))
PermittedTagFilter, Depends(permitted_dag_filter_factory("GET", None, PermittedTagFilter))
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,14 +273,25 @@ def filter_authorized_dag_ids(
dag_ids: set[str],
user: AwsAuthManagerUser,
method: ResourceMethod = "GET",
access_entity: DagAccessEntity | None = None,
):
requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict)
requests_list: list[IsAuthorizedRequest] = []
for dag_id in dag_ids:
context = (
None
if access_entity is None
else {
"dag_entity": {
"string": access_entity.value,
},
}
)
request: IsAuthorizedRequest = {
"method": method,
"entity_type": AvpEntities.DAG,
"entity_id": dag_id,
"context": context,
}
requests[dag_id][method] = request
requests_list.append(request)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -427,13 +427,24 @@ def get_authorized_dag_ids(
*,
user: User,
method: ResourceMethod = "GET",
access_entity: DagAccessEntity | None = None,
session: Session = NEW_SESSION,
) -> set[str]:
if self._is_authorized(method=method, resource_type=RESOURCE_DAG, user=user):
# If user is authorized to access all DAGs, return all DAGs
return {dag.dag_id for dag in session.execute(select(DagModel.dag_id))}
if isinstance(user, AnonymousUser):
return set()
if access_entity:
resource_types = self._get_fab_resource_types(access_entity)
if not all(
self._is_authorized(method=method, resource_type=resource_type, user=user)
for resource_type in resource_types
):
# If `access_entity` is provided and the user is not authorized to access this given `access_entity`, return empty set
return set()

dag_method: ResourceMethod = "GET" if method == "GET" else "PUT"
user_query = session.scalar(
select(User)
.options(
Expand All @@ -452,15 +463,13 @@ def get_authorized_dag_ids(
action = permission.action.name
if (
action in map_fab_action_name_to_method_name
and map_fab_action_name_to_method_name[action] == method
and map_fab_action_name_to_method_name[action] == dag_method
):
resource = permission.resource.name
if resource == permissions.RESOURCE_DAG:
return {dag.dag_id for dag in session.execute(select(DagModel.dag_id))}
if resource.startswith(permissions.RESOURCE_DAG_PREFIX):
resources.add(resource[len(permissions.RESOURCE_DAG_PREFIX) :])
else:
resources.add(resource)
return set(session.scalars(select(DagModel.dag_id).where(DagModel.dag_id.in_(resources))))

@cached_property
Expand Down
Loading