Skip to content
4 changes: 4 additions & 0 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3933,6 +3933,8 @@ paths:
summary: Get Import Error
description: Get an import error.
operationId: get_import_error
security:
- OAuth2PasswordBearer: []
parameters:
- name: import_error_id
in: path
Expand Down Expand Up @@ -3978,6 +3980,8 @@ paths:
summary: Get Import Errors
description: Get all import errors.
operationId: get_import_errors
security:
- OAuth2PasswordBearer: []
parameters:
- name: limit
in: query
Expand Down
98 changes: 97 additions & 1 deletion airflow/api_fastapi/core_api/routes/public/import_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,19 @@
# under the License.
from __future__ import annotations

from collections.abc import Iterable, Sequence
from itertools import groupby
from operator import itemgetter
from typing import Annotated

from fastapi import Depends, HTTPException, status
from sqlalchemy import select

from airflow.api_fastapi.app import get_auth_manager
from airflow.api_fastapi.auth.managers.models.batch_apis import IsAuthorizedDagRequest
from airflow.api_fastapi.auth.managers.models.resource_details import (
DagDetails,
)
from airflow.api_fastapi.common.db.common import (
SessionDep,
paginated_select,
Expand All @@ -36,18 +44,29 @@
ImportErrorResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import (
AccessView,
GetUserDep,
requires_access_view,
)
from airflow.models import DagModel
from airflow.models.errors import ParseImportError

REDACTED_STACKTRACE = "REDACTED - you do not have read permission on all DAGs in the file"
import_error_router = AirflowRouter(tags=["Import Error"], prefix="/importErrors")


@import_error_router.get(
"/{import_error_id}",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
dependencies=[
Depends(requires_access_view(AccessView.IMPORT_ERRORS)),
],
)
def get_import_error(
import_error_id: int,
session: SessionDep,
user: GetUserDep,
) -> ImportErrorResponse:
"""Get an import error."""
error = session.scalar(select(ParseImportError).where(ParseImportError.id == import_error_id))
Expand All @@ -56,12 +75,37 @@ def get_import_error(
status.HTTP_404_NOT_FOUND,
f"The ImportError with import_error_id: `{import_error_id}` was not found",
)
session.expunge(error)

auth_manager = get_auth_manager()
can_read_all_dags = auth_manager.is_authorized_dag(method="GET", user=user)
if can_read_all_dags:
# Early return if the user has access to all DAGs
return error

readable_dag_ids = auth_manager.get_authorized_dag_ids(user=user)
# We need file_dag_ids as a set for intersection, issubset operations
file_dag_ids = set(
session.scalars(select(DagModel.dag_id).where(DagModel.fileloc == error.filename)).all()
)
# Can the user read any DAGs in the file?
if not readable_dag_ids.intersection(file_dag_ids):
raise HTTPException(
status.HTTP_403_FORBIDDEN,
"You do not have read permission on any of the DAGs in the file",
)

# Check if user has read access to all the DAGs defined in the file
if not file_dag_ids.issubset(readable_dag_ids):
error.stacktrace = REDACTED_STACKTRACE
return error


@import_error_router.get(
"",
dependencies=[
Depends(requires_access_view(AccessView.IMPORT_ERRORS)),
],
)
def get_import_errors(
limit: QueryLimit,
Expand All @@ -83,6 +127,7 @@ def get_import_errors(
),
],
session: SessionDep,
user: GetUserDep,
) -> ImportErrorCollectionResponse:
"""Get all import errors."""
import_errors_select, total_entries = paginated_select(
Expand All @@ -92,7 +137,58 @@ def get_import_errors(
limit=limit,
session=session,
)
import_errors = session.scalars(import_errors_select)

auth_manager = get_auth_manager()
can_read_all_dags = auth_manager.is_authorized_dag(method="GET", user=user)
if can_read_all_dags:
# Early return if the user has access to all DAGs
import_errors = session.scalars(import_errors_select).all()
return ImportErrorCollectionResponse(
import_errors=import_errors,
total_entries=total_entries,
)

# If the user doesn't have access to all DAGs, only return import errors for files that contain at least one DAG the user can read
readable_dag_ids = auth_manager.get_authorized_dag_ids(method="GET", user=user)
# Build a CTE of (fileloc, dag_id) pairs, restricted to the DAGs the user can read
visiable_files_cte = (
select(DagModel.fileloc, DagModel.dag_id).where(DagModel.dag_id.in_(readable_dag_ids)).cte()
)

# Prepare the import errors query by joining with the cte.
# Each returned row will be a tuple: (ParseImportError, dag_id)
import_errors_stmt = (
select(ParseImportError, visiable_files_cte.c.dag_id)
.join(visiable_files_cte, ParseImportError.filename == visiable_files_cte.c.fileloc)
.order_by(ParseImportError.id)
)

# Paginate the import errors query
import_errors_select, total_entries = paginated_select(
statement=import_errors_stmt,
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)
import_errors_result: Iterable[tuple[ParseImportError, Iterable[str]]] = groupby(
session.execute(import_errors_select), itemgetter(0)
)

import_errors = []
for import_error, file_dag_ids in import_errors_result:
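# Check if user has read access to all the DAGs defined in the file
# NOTE(review): file_dag_ids is a groupby group over the executed result rows, and its
# dag_id column comes from the CTE that is already filtered to readable DAGs. Confirm that
# the values passed to DagDetails are plain dag_ids and that DAGs the user cannot read are
# actually taken into account here.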
requests: Sequence[IsAuthorizedDagRequest] = [
{
"method": "GET",
"details": DagDetails(id=dag_id),
}
for dag_id in file_dag_ids
]
if not auth_manager.batch_is_authorized_dag(requests, user=user):
session.expunge(import_error)
import_error.stacktrace = REDACTED_STACKTRACE
import_errors.append(import_error)

return ImportErrorCollectionResponse(
import_errors=import_errors,
Expand Down
Loading