diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/event_logs.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/event_logs.py index 8909b5ff1bc64..529fbb94a6be4 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/event_logs.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/event_logs.py @@ -43,7 +43,11 @@ EventLogResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc -from airflow.api_fastapi.core_api.security import DagAccessEntity, requires_access_dag +from airflow.api_fastapi.core_api.security import ( + DagAccessEntity, + ReadableEventLogsFilterDep, + requires_access_dag, +) from airflow.models import Log event_logs_router = AirflowRouter(tags=["Event Log"], prefix="/eventLogs") @@ -126,6 +130,7 @@ def get_event_logs( run_id_pattern: Annotated[_SearchParam, Depends(search_param_factory(Log.run_id, "run_id_pattern"))], owner_pattern: Annotated[_SearchParam, Depends(search_param_factory(Log.owner, "owner_pattern"))], event_pattern: Annotated[_SearchParam, Depends(search_param_factory(Log.event, "event_pattern"))], + readable_event_logs_filter: ReadableEventLogsFilterDep, ) -> EventLogCollectionResponse: """Get all Event Logs.""" query = select(Log).options(joinedload(Log.task_instance), joinedload(Log.dag_model)) @@ -151,6 +156,8 @@ def get_event_logs( run_id_pattern, owner_pattern, event_pattern, + # Permission + readable_event_logs_filter, ], offset=offset, limit=limit, diff --git a/airflow-core/src/airflow/api_fastapi/core_api/security.py b/airflow-core/src/airflow/api_fastapi/core_api/security.py index e17f92776da12..1db6749d1f532 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/security.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/security.py @@ -25,6 +25,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer, OAuth2PasswordBearer from jwt import ExpiredSignatureError, InvalidTokenError from pydantic import NonNegativeInt +from sqlalchemy import or_ from airflow.api_fastapi.app import get_auth_manager from airflow.api_fastapi.auth.managers.base_auth_manager import COOKIE_NAME_JWT_TOKEN @@ -62,6 +63,7 @@ from airflow.models import Connection, Pool, Variable from airflow.models.dag import DagModel, DagRun, DagTag from airflow.models.dagwarning import DagWarning +from airflow.models.log import Log from airflow.models.taskinstance import TaskInstance as TI from airflow.models.xcom import XComModel @@ -170,6 +172,15 @@ def to_orm(self, select: Select) -> Select: return select.where(DagWarning.dag_id.in_(self.value)) +class PermittedEventLogFilter(PermittedDagFilter): + """A parameter that filters the permitted even logs for the user.""" + + def to_orm(self, select: Select) -> Select: + # Event Logs not related to Dags have dag_id as None and are always returned. + # return select.where(Log.dag_id.in_(self.value or set()) or Log.dag_id.is_(None)) + return select.where(or_(Log.dag_id.in_(self.value or set()), Log.dag_id.is_(None))) + + class PermittedTIFilter(PermittedDagFilter): """A parameter that filters the permitted task instances for the user.""" @@ -223,6 +234,9 @@ def depends_permitted_dags_filter( ReadableTIFilterDep = Annotated[ PermittedTIFilter, Depends(permitted_dag_filter_factory("GET", PermittedTIFilter)) ] +ReadableEventLogsFilterDep = Annotated[ + PermittedTIFilter, Depends(permitted_dag_filter_factory("GET", PermittedEventLogFilter)) +] ReadableXComFilterDep = Annotated[ PermittedXComFilter, Depends(permitted_dag_filter_factory("GET", PermittedXComFilter)) ] diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_event_logs.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_event_logs.py index dd87e8d7fac2c..0d68107ea1a3f 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_event_logs.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_event_logs.py @@ -302,7 +302,7 @@ class TestGetEventLogs(TestEventLogsEndpoint): def test_get_event_logs( self, test_client, query_params, expected_status_code, expected_total_entries, expected_events ): - with assert_queries_count(2): + with assert_queries_count(3): response = test_client.get("/eventLogs", params=query_params) assert response.status_code == expected_status_code if expected_status_code != 200: @@ -341,7 +341,7 @@ def test_get_event_logs( def test_get_event_logs_order_by( self, test_client, query_params, expected_status_code, expected_total_entries, expected_events ): - with assert_queries_count(2): + with assert_queries_count(3): response = test_client.get("/eventLogs", params=query_params) assert response.status_code == expected_status_code if expected_status_code != 200: