Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

from collections.abc import Iterable
from datetime import datetime

from pydantic import Field
Expand All @@ -36,5 +37,5 @@ class ImportErrorResponse(BaseModel):
class ImportErrorCollectionResponse(BaseModel):
"""Import Error Collection Response."""

import_errors: list[ImportErrorResponse]
import_errors: Iterable[ImportErrorResponse]
total_entries: int
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,6 @@ def get_import_error(
session.expunge(error)

auth_manager = get_auth_manager()
can_read_all_dags = auth_manager.is_authorized_dag(method="GET", user=user)
if can_read_all_dags:
# Early return if the user has access to all DAGs
return error

readable_dag_ids = auth_manager.get_authorized_dag_ids(user=user)
# We need file_dag_ids as a set for intersection, issubset operations
file_dag_ids = set(
Expand Down Expand Up @@ -132,26 +127,7 @@ def get_import_errors(
user: GetUserDep,
) -> ImportErrorCollectionResponse:
"""Get all import errors."""
import_errors_select, total_entries = paginated_select(
statement=select(ParseImportError),
filters=[filename_pattern],
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)

auth_manager = get_auth_manager()
can_read_all_dags = auth_manager.is_authorized_dag(method="GET", user=user)
if can_read_all_dags:
# Early return if the user has access to all DAGs
import_errors = session.scalars(import_errors_select).all()
return ImportErrorCollectionResponse(
import_errors=import_errors,
total_entries=total_entries,
)

# if the user doesn't have access to all DAGs, only display errors from visible DAGs
readable_dag_ids = auth_manager.get_authorized_dag_ids(method="GET", user=user)
# Build a cte that fetches dag_ids for each file location
visible_files_cte = (
Expand Down Expand Up @@ -183,7 +159,7 @@ def get_import_errors(
limit=limit,
session=session,
)
import_errors_result: Iterable[tuple[ParseImportError, Iterable[str]]] = groupby(
import_errors_result: Iterable[tuple[ParseImportError, Iterable]] = groupby(
session.execute(import_errors_select), itemgetter(0)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,37 @@ def permitted_dag_model(testing_dag_bundle, session: Session = NEW_SESSION) -> D
return dag_model


@pytest.fixture
@provide_session
def permitted_dag_model_all(testing_dag_bundle, session: Session = NEW_SESSION) -> set[str]:
dag_model1 = DagModel(
fileloc=FILENAME1,
relative_fileloc=FILENAME1,
dag_id="dag_id1",
is_paused=False,
bundle_name=BUNDLE_NAME,
)
dag_model2 = DagModel(
fileloc=FILENAME2,
relative_fileloc=FILENAME2,
dag_id="dag_id2",
is_paused=False,
bundle_name=BUNDLE_NAME,
)
dag_model3 = DagModel(
fileloc=FILENAME3,
relative_fileloc=FILENAME3,
dag_id="dag_id3",
is_paused=False,
bundle_name=BUNDLE_NAME,
)
session.add(dag_model1)
session.add(dag_model2)
session.add(dag_model3)
session.commit()
return {dag_model1.dag_id, dag_model2.dag_id, dag_model3.dag_id}


@pytest.fixture
@provide_session
def not_permitted_dag_model(testing_dag_bundle, session: Session = NEW_SESSION) -> DagModel:
Expand Down Expand Up @@ -104,7 +135,7 @@ def import_errors(session: Session = NEW_SESSION) -> list[ParseImportError]:
timestamp=timestamp,
)
for bundle, filename, stacktrace, timestamp in zip(
(BUNDLE_NAME, BUNDLE_NAME, None),
(BUNDLE_NAME, BUNDLE_NAME, BUNDLE_NAME),
(FILENAME1, FILENAME2, FILENAME3),
(STACKTRACE1, STACKTRACE2, STACKTRACE3),
(TIMESTAMP1, TIMESTAMP2, TIMESTAMP3),
Expand All @@ -115,14 +146,6 @@ def import_errors(session: Session = NEW_SESSION) -> list[ParseImportError]:
return _import_errors


def set_mock_auth_manager__is_authorized_dag(
mock_auth_manager: mock.Mock, is_authorized_dag_return_value: bool = False
) -> mock.Mock:
mock_is_authorized_dag = mock_auth_manager.return_value.is_authorized_dag
mock_is_authorized_dag.return_value = is_authorized_dag_return_value
return mock_is_authorized_dag


def set_mock_auth_manager__get_authorized_dag_ids(
mock_auth_manager: mock.Mock, get_authorized_dag_ids_return_value: set[str] | None = None
) -> mock.Mock:
Expand Down Expand Up @@ -172,19 +195,28 @@ class TestGetImportError:
"timestamp": from_datetime_to_zulu_without_ms(TIMESTAMP3),
"filename": FILENAME3,
"stack_trace": STACKTRACE3,
"bundle_name": None,
"bundle_name": BUNDLE_NAME,
},
),
(None, 404, {}),
],
)
@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
def test_get_import_error(
self, prepared_import_error_idx, expected_status_code, expected_body, test_client, import_errors
self,
mock_get_auth_manager,
prepared_import_error_idx,
expected_status_code,
expected_body,
test_client,
permitted_dag_model_all,
import_errors,
):
import_error: ParseImportError | None = (
import_errors[prepared_import_error_idx] if prepared_import_error_idx is not None else None
)
import_error_id = import_error.id if import_error else IMPORT_ERROR_NON_EXISTED_ID
set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager, permitted_dag_model_all)
response = test_client.get(f"/importErrors/{import_error_id}")
assert response.status_code == expected_status_code
if expected_status_code != 200:
Expand All @@ -209,23 +241,25 @@ def test_should_raises_403_unauthorized__user_can_not_read_any_dags_in_file(
):
import_error_id = import_errors[0].id
# Mock auth_manager
mock_is_authorized_dag = set_mock_auth_manager__is_authorized_dag(mock_get_auth_manager)
mock_get_authorized_dag_ids = set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager)
# Act
response = test_client.get(f"/importErrors/{import_error_id}")
# Assert
mock_is_authorized_dag.assert_called_once_with(method="GET", user=mock.ANY)
mock_get_authorized_dag_ids.assert_called_once_with(user=mock.ANY)
assert response.status_code == 403
assert response.json() == {"detail": "You do not have read permission on any of the DAGs in the file"}

@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
def test_get_import_error__user_dont_have_read_permission_to_read_all_dags_in_file(
self, mock_get_auth_manager, test_client, permitted_dag_model, not_permitted_dag_model, import_errors
self,
mock_get_auth_manager,
test_client,
permitted_dag_model_all,
not_permitted_dag_model,
import_errors,
):
import_error_id = import_errors[0].id
set_mock_auth_manager__is_authorized_dag(mock_get_auth_manager)
set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager, {permitted_dag_model.dag_id})
set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager, permitted_dag_model_all)
# Act
response = test_client.get(f"/importErrors/{import_error_id}")
# Assert
Expand Down Expand Up @@ -315,15 +349,21 @@ class TestGetImportErrors:
),
],
)
@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
def test_get_import_errors(
self,
mock_get_auth_manager,
test_client,
query_params,
expected_status_code,
expected_total_entries,
expected_filenames,
permitted_dag_model_all,
):
with assert_queries_count(2):
set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager, permitted_dag_model_all)
set_mock_auth_manager__batch_is_authorized_dag(mock_get_auth_manager, True)

with assert_queries_count(5):
response = test_client.get("/importErrors", params=query_params)

assert response.status_code == expected_status_code
Expand Down Expand Up @@ -366,7 +406,6 @@ def test_user_can_not_read_all_dags_in_file(
permitted_dag_model,
import_errors,
):
set_mock_auth_manager__is_authorized_dag(mock_get_auth_manager)
mock_get_authorized_dag_ids = set_mock_auth_manager__get_authorized_dag_ids(
mock_get_auth_manager, {permitted_dag_model.dag_id}
)
Expand Down Expand Up @@ -399,7 +438,6 @@ def test_bundle_name_join_condition_for_import_errors(
self, mock_get_auth_manager, test_client, permitted_dag_model, import_errors, session
):
"""Test that the bundle_name join condition works correctly."""
set_mock_auth_manager__is_authorized_dag(mock_get_auth_manager)
mock_get_authorized_dag_ids = set_mock_auth_manager__get_authorized_dag_ids(
mock_get_auth_manager, {permitted_dag_model.dag_id}
)
Expand Down