diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 3611dcdee3dd6..8cc1cdfd0eb60 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -3933,6 +3933,8 @@ paths: summary: Get Import Error description: Get an import error. operationId: get_import_error + security: + - OAuth2PasswordBearer: [] parameters: - name: import_error_id in: path @@ -3978,6 +3980,8 @@ paths: summary: Get Import Errors description: Get all import errors. operationId: get_import_errors + security: + - OAuth2PasswordBearer: [] parameters: - name: limit in: query diff --git a/airflow/api_fastapi/core_api/routes/public/import_error.py b/airflow/api_fastapi/core_api/routes/public/import_error.py index 01caf9048e2d6..4beb0ea2cd416 100644 --- a/airflow/api_fastapi/core_api/routes/public/import_error.py +++ b/airflow/api_fastapi/core_api/routes/public/import_error.py @@ -16,11 +16,19 @@ # under the License. from __future__ import annotations +from collections.abc import Iterable, Sequence +from itertools import groupby +from operator import itemgetter from typing import Annotated from fastapi import Depends, HTTPException, status from sqlalchemy import select +from airflow.api_fastapi.app import get_auth_manager +from airflow.api_fastapi.auth.managers.models.batch_apis import IsAuthorizedDagRequest +from airflow.api_fastapi.auth.managers.models.resource_details import ( + DagDetails, +) from airflow.api_fastapi.common.db.common import ( SessionDep, paginated_select, @@ -36,18 +44,29 @@ ImportErrorResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc +from airflow.api_fastapi.core_api.security import ( + AccessView, + GetUserDep, + requires_access_view, +) +from airflow.models import DagModel from airflow.models.errors import ParseImportError +REDACTED_STACKTRACE = "REDACTED - you do not have read permission on all DAGs in the file" import_error_router = AirflowRouter(tags=["Import Error"], prefix="/importErrors") @import_error_router.get( "/{import_error_id}", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[ + Depends(requires_access_view(AccessView.IMPORT_ERRORS)), + ], ) def get_import_error( import_error_id: int, session: SessionDep, + user: GetUserDep, ) -> ImportErrorResponse: """Get an import error.""" error = session.scalar(select(ParseImportError).where(ParseImportError.id == import_error_id)) @@ -56,12 +75,37 @@ def get_import_error( status.HTTP_404_NOT_FOUND, f"The ImportError with import_error_id: `{import_error_id}` was not found", ) + session.expunge(error) + + auth_manager = get_auth_manager() + can_read_all_dags = auth_manager.is_authorized_dag(method="GET", user=user) + if can_read_all_dags: + # Early return if the user has access to all DAGs + return error + + readable_dag_ids = auth_manager.get_authorized_dag_ids(user=user) + # We need file_dag_ids as a set for intersection, issubset operations + file_dag_ids = set( + session.scalars(select(DagModel.dag_id).where(DagModel.fileloc == error.filename)).all() + ) + # Can the user read any DAGs in the file? + if not readable_dag_ids.intersection(file_dag_ids): + raise HTTPException( + status.HTTP_403_FORBIDDEN, + "You do not have read permission on any of the DAGs in the file", + ) + # Check if user has read access to all the DAGs defined in the file + if not file_dag_ids.issubset(readable_dag_ids): + error.stacktrace = REDACTED_STACKTRACE return error @import_error_router.get( "", + dependencies=[ + Depends(requires_access_view(AccessView.IMPORT_ERRORS)), + ], ) def get_import_errors( limit: QueryLimit, @@ -83,6 +127,7 @@ def get_import_errors( ), ], session: SessionDep, + user: GetUserDep, ) -> ImportErrorCollectionResponse: """Get all import errors.""" import_errors_select, total_entries = paginated_select( @@ -92,7 +137,58 @@ def get_import_errors( limit=limit, session=session, ) - import_errors = session.scalars(import_errors_select) + + auth_manager = get_auth_manager() + can_read_all_dags = auth_manager.is_authorized_dag(method="GET", user=user) + if can_read_all_dags: + # Early return if the user has access to all DAGs + import_errors = session.scalars(import_errors_select).all() + return ImportErrorCollectionResponse( + import_errors=import_errors, + total_entries=total_entries, + ) + + # if the user doesn't have access to all DAGs, only display errors from visible DAGs + readable_dag_ids = auth_manager.get_authorized_dag_ids(method="GET", user=user) + # Build a cte that fetches dag_ids for each file location + visiable_files_cte = ( + select(DagModel.fileloc, DagModel.dag_id).where(DagModel.dag_id.in_(readable_dag_ids)).cte() + ) + + # Prepare the import errors query by joining with the cte. + # Each returned row will be a tuple: (ParseImportError, dag_id) + import_errors_stmt = ( + select(ParseImportError, visiable_files_cte.c.dag_id) + .join(visiable_files_cte, ParseImportError.filename == visiable_files_cte.c.fileloc) + .order_by(ParseImportError.id) + ) + + # Paginate the import errors query + import_errors_select, total_entries = paginated_select( + statement=import_errors_stmt, + order_by=order_by, + offset=offset, + limit=limit, + session=session, + ) + import_errors_result: Iterable[tuple[ParseImportError, Iterable[str]]] = groupby( + session.execute(import_errors_select), itemgetter(0) + ) + + import_errors = [] + for import_error, file_dag_ids in import_errors_result: + # Check if user has read access to all the DAGs defined in the file + requests: Sequence[IsAuthorizedDagRequest] = [ + { + "method": "GET", + "details": DagDetails(id=dag_id), + } + for dag_id in file_dag_ids + ] + if not auth_manager.batch_is_authorized_dag(requests, user=user): + session.expunge(import_error) + import_error.stacktrace = REDACTED_STACKTRACE + import_errors.append(import_error) return ImportErrorCollectionResponse( import_errors=import_errors, diff --git a/tests/api_fastapi/core_api/routes/public/test_import_error.py b/tests/api_fastapi/core_api/routes/public/test_import_error.py index 38965ab7184c1..84d7871585b7d 100644 --- a/tests/api_fastapi/core_api/routes/public/test_import_error.py +++ b/tests/api_fastapi/core_api/routes/public/test_import_error.py @@ -17,15 +17,21 @@ from __future__ import annotations from datetime import datetime, timezone +from typing import TYPE_CHECKING +from unittest import mock import pytest +from airflow.models import DagModel from airflow.models.errors import ParseImportError -from airflow.utils.session import provide_session +from airflow.utils.session import NEW_SESSION, provide_session -from tests_common.test_utils.db import clear_db_import_errors +from tests_common.test_utils.db import clear_db_dags, clear_db_import_errors from tests_common.test_utils.format_datetime import from_datetime_to_zulu_without_ms +if TYPE_CHECKING: + from sqlalchemy.orm import Session + pytestmark = pytest.mark.db_test FILENAME1 = "test_filename1.py" @@ -42,95 +48,171 @@ BUNDLE_NAME = "dag_maker" -class TestImportErrorEndpoint: - """Common class for /public/importErrors related unit tests.""" +@pytest.fixture(scope="class") +@provide_session +def permitted_dag_model(session: Session = NEW_SESSION) -> DagModel: + dag_model = DagModel(fileloc=FILENAME1, dag_id="dag_id1", is_paused=False) + session.add(dag_model) + session.commit() + return dag_model - @staticmethod - def _clear_db(): - clear_db_import_errors() - @pytest.fixture(autouse=True) - @provide_session - def setup(self, session=None) -> dict[str, ParseImportError]: - """ - Setup method which is run before every test. - """ - self._clear_db() - import_error1 = ParseImportError( - bundle_name=BUNDLE_NAME, - filename=FILENAME1, - stacktrace=STACKTRACE1, - timestamp=TIMESTAMP1, - ) - import_error2 = ParseImportError( +@pytest.fixture(scope="class") +@provide_session +def not_permitted_dag_model(session: Session = NEW_SESSION) -> DagModel: + dag_model = DagModel(fileloc=FILENAME1, dag_id="dag_id4", is_paused=False) + session.add(dag_model) + session.commit() + return dag_model + + +@pytest.fixture(scope="class", autouse=True) +def clear_db(): + clear_db_import_errors() + clear_db_dags() + + yield + + clear_db_import_errors() + clear_db_dags() + + +@pytest.fixture(autouse=True, scope="class") +@provide_session +def import_errors(session: Session = NEW_SESSION) -> list[ParseImportError]: + _import_errors = [ + ParseImportError( bundle_name=BUNDLE_NAME, - filename=FILENAME2, - stacktrace=STACKTRACE2, - timestamp=TIMESTAMP2, + filename=filename, + stacktrace=stacktrace, + timestamp=timestamp, ) - import_error3 = ParseImportError( - bundle_name=BUNDLE_NAME, - filename=FILENAME3, - stacktrace=STACKTRACE3, - timestamp=TIMESTAMP3, + for filename, stacktrace, timestamp in zip( + (FILENAME1, FILENAME2, FILENAME3), + (STACKTRACE1, STACKTRACE2, STACKTRACE3), + (TIMESTAMP1, TIMESTAMP2, TIMESTAMP3), ) - session.add_all([import_error1, import_error2, import_error3]) - session.commit() - return {FILENAME1: import_error1, FILENAME2: import_error2, FILENAME3: import_error3} + ] + + session.add_all(_import_errors) + return _import_errors + - def teardown_method(self) -> None: - self._clear_db() +def set_mock_auth_manager__is_authorized_dag( + mock_auth_manager: mock.Mock, is_authorized_dag_return_value: bool = False +) -> mock.Mock: + mock_is_authorized_dag = mock_auth_manager.return_value.is_authorized_dag + mock_is_authorized_dag.return_value = is_authorized_dag_return_value + return mock_is_authorized_dag -class TestGetImportError(TestImportErrorEndpoint): +def set_mock_auth_manager__get_authorized_dag_ids( + mock_auth_manager: mock.Mock, get_authorized_dag_ids_return_value: set[str] | None = None +) -> mock.Mock: + if get_authorized_dag_ids_return_value is None: + get_authorized_dag_ids_return_value = set() + mock_get_authorized_dag_ids = mock_auth_manager.return_value.get_authorized_dag_ids + mock_get_authorized_dag_ids.return_value = get_authorized_dag_ids_return_value + return mock_get_authorized_dag_ids + + +def set_mock_auth_manager__batch_is_authorized_dag( + mock_auth_manager: mock.Mock, batch_is_authorized_dag_return_value: bool = False +) -> mock.Mock: + mock_batch_is_authorized_dag = mock_auth_manager.return_value.batch_is_authorized_dag + mock_batch_is_authorized_dag.return_value = batch_is_authorized_dag_return_value + return mock_batch_is_authorized_dag + + +class TestGetImportError: @pytest.mark.parametrize( - "import_error_key, expected_status_code, expected_body", + "prepared_import_error_idx, expected_status_code, expected_body", [ ( - FILENAME1, + 0, 200, { - "import_error_id": 1, - "timestamp": TIMESTAMP1, + "timestamp": from_datetime_to_zulu_without_ms(TIMESTAMP1), "filename": FILENAME1, "stack_trace": STACKTRACE1, "bundle_name": BUNDLE_NAME, }, ), ( - FILENAME2, + 1, 200, { - "import_error_id": 2, - "timestamp": TIMESTAMP2, + "timestamp": from_datetime_to_zulu_without_ms(TIMESTAMP2), "filename": FILENAME2, "stack_trace": STACKTRACE2, "bundle_name": BUNDLE_NAME, }, ), - (IMPORT_ERROR_NON_EXISTED_KEY, 404, {}), + (None, 404, {}), ], ) def test_get_import_error( - self, test_client, setup, import_error_key, expected_status_code, expected_body + self, prepared_import_error_idx, expected_status_code, expected_body, test_client, import_errors ): - import_error: ParseImportError | None = setup.get(import_error_key) + import_error: ParseImportError | None = ( + import_errors[prepared_import_error_idx] if prepared_import_error_idx is not None else None + ) import_error_id = import_error.id if import_error else IMPORT_ERROR_NON_EXISTED_ID response = test_client.get(f"/public/importErrors/{import_error_id}") assert response.status_code == expected_status_code if expected_status_code != 200: return - expected_json = { + + expected_body.update({"import_error_id": import_error_id}) + assert response.json() == expected_body + + def test_should_raises_401_unauthenticated(self, unauthenticated_test_client, import_errors): + import_error_id = import_errors[0].id + response = unauthenticated_test_client.get(f"/public/importErrors/{import_error_id}") + assert response.status_code == 401 + + def test_should_raises_403_unauthorized(self, unauthorized_test_client, import_errors): + import_error_id = import_errors[0].id + response = unauthorized_test_client.get(f"/public/importErrors/{import_error_id}") + assert response.status_code == 403 + + @mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager") + def test_should_raises_403_unauthorized__user_can_not_read_any_dags_in_file( + self, mock_get_auth_manager, test_client, import_errors + ): + import_error_id = import_errors[0].id + # Mock auth_manager + mock_is_authorized_dag = set_mock_auth_manager__is_authorized_dag(mock_get_auth_manager) + mock_get_authorized_dag_ids = set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager) + # Act + response = test_client.get(f"/public/importErrors/{import_error_id}") + # Assert + mock_is_authorized_dag.assert_called_once_with(method="GET", user=mock.ANY) + mock_get_authorized_dag_ids.assert_called_once_with(user=mock.ANY) + assert response.status_code == 403 + assert response.json() == {"detail": "You do not have read permission on any of the DAGs in the file"} + + @mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager") + def test_get_import_error__user_dont_have_read_permission_to_read_all_dags_in_file( + self, mock_get_auth_manager, test_client, permitted_dag_model, not_permitted_dag_model, import_errors + ): + import_error_id = import_errors[0].id + set_mock_auth_manager__is_authorized_dag(mock_get_auth_manager) + set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager, {permitted_dag_model.dag_id}) + # Act + response = test_client.get(f"/public/importErrors/{import_error_id}") + # Assert + assert response.status_code == 200 + assert response.json() == { "import_error_id": import_error_id, - "timestamp": from_datetime_to_zulu_without_ms(expected_body["timestamp"]), - "filename": expected_body["filename"], - "stack_trace": expected_body["stack_trace"], + "timestamp": from_datetime_to_zulu_without_ms(TIMESTAMP1), + "filename": FILENAME1, + "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file", "bundle_name": BUNDLE_NAME, } - assert response.json() == expected_json -class TestGetImportErrors(TestImportErrorEndpoint): +class TestGetImportErrors: @pytest.mark.parametrize( "query_params, expected_status_code, expected_total_entries, expected_filenames", [ @@ -225,3 +307,59 @@ def test_get_import_errors( assert [ import_error["filename"] for import_error in response_json["import_errors"] ] == expected_filenames + + def test_should_raises_401_unauthenticated(self, unauthenticated_test_client): + response = unauthenticated_test_client.get("/public/importErrors") + assert response.status_code == 401 + + def test_should_raises_403_unauthorized(self, unauthorized_test_client): + response = unauthorized_test_client.get("/public/importErrors") + assert response.status_code == 403 + + @pytest.mark.parametrize( + "batch_is_authorized_dag_return_value, expected_stack_trace", + [ + pytest.param(True, STACKTRACE1, id="user_has_read_access_to_all_dags_in_current_file"), + pytest.param( + False, + "REDACTED - you do not have read permission on all DAGs in the file", + id="user_does_not_have_read_access_to_all_dags_in_current_file", + ), + ], + ) + @pytest.mark.usefixtures("permitted_dag_model") + @mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager") + def test_user_can_not_read_all_dags_in_file( + self, + mock_get_auth_manager, + test_client, + batch_is_authorized_dag_return_value, + expected_stack_trace, + permitted_dag_model, + import_errors, + ): + set_mock_auth_manager__is_authorized_dag(mock_get_auth_manager) + mock_get_authorized_dag_ids = set_mock_auth_manager__get_authorized_dag_ids( + mock_get_auth_manager, {permitted_dag_model.dag_id} + ) + set_mock_auth_manager__batch_is_authorized_dag( + mock_get_auth_manager, batch_is_authorized_dag_return_value + ) + # Act + response = test_client.get("/public/importErrors") + # Assert + mock_get_authorized_dag_ids.assert_called_once_with(method="GET", user=mock.ANY) + assert response.status_code == 200 + response_json = response.json() + assert response_json == { + "total_entries": 1, + "import_errors": [ + { + "import_error_id": import_errors[0].id, + "timestamp": from_datetime_to_zulu_without_ms(TIMESTAMP1), + "filename": FILENAME1, + "stack_trace": expected_stack_trace, + "bundle_name": BUNDLE_NAME, + } + ], + }