diff --git a/airflow/api_fastapi/common/db/common.py b/airflow/api_fastapi/common/db/common.py index 84297724ce025..363a7fbdc4b0f 100644 --- a/airflow/api_fastapi/common/db/common.py +++ b/airflow/api_fastapi/common/db/common.py @@ -35,7 +35,7 @@ if TYPE_CHECKING: from sqlalchemy.sql import Select - from airflow.api_fastapi.common.parameters import BaseParam + from airflow.api_fastapi.core_api.base import OrmClause def _get_session() -> Session: @@ -47,7 +47,7 @@ def _get_session() -> Session: def apply_filters_to_select( - *, statement: Select, filters: Sequence[BaseParam | None] | None = None + *, statement: Select, filters: Sequence[OrmClause | None] | None = None ) -> Select: if filters is None: return statement @@ -71,10 +71,10 @@ async def _get_async_session() -> AsyncSession: async def paginated_select_async( *, statement: Select, - filters: Sequence[BaseParam] | None = None, - order_by: BaseParam | None = None, - offset: BaseParam | None = None, - limit: BaseParam | None = None, + filters: Sequence[OrmClause] | None = None, + order_by: OrmClause | None = None, + offset: OrmClause | None = None, + limit: OrmClause | None = None, session: AsyncSession, return_total_entries: Literal[True] = True, ) -> tuple[Select, int]: ... @@ -84,10 +84,10 @@ async def paginated_select_async( async def paginated_select_async( *, statement: Select, - filters: Sequence[BaseParam] | None = None, - order_by: BaseParam | None = None, - offset: BaseParam | None = None, - limit: BaseParam | None = None, + filters: Sequence[OrmClause] | None = None, + order_by: OrmClause | None = None, + offset: OrmClause | None = None, + limit: OrmClause | None = None, session: AsyncSession, return_total_entries: Literal[False], ) -> tuple[Select, None]: ... @@ -96,10 +96,10 @@ async def paginated_select_async( async def paginated_select_async( *, statement: Select, - filters: Sequence[BaseParam | None] | None = None, - order_by: BaseParam | None = None, - offset: BaseParam | None = None, - limit: BaseParam | None = None, + filters: Sequence[OrmClause | None] | None = None, + order_by: OrmClause | None = None, + offset: OrmClause | None = None, + limit: OrmClause | None = None, session: AsyncSession, return_total_entries: bool = True, ) -> tuple[Select, int | None]: @@ -129,10 +129,10 @@ async def paginated_select_async( def paginated_select( *, statement: Select, - filters: Sequence[BaseParam] | None = None, - order_by: BaseParam | None = None, - offset: BaseParam | None = None, - limit: BaseParam | None = None, + filters: Sequence[OrmClause] | None = None, + order_by: OrmClause | None = None, + offset: OrmClause | None = None, + limit: OrmClause | None = None, session: Session = NEW_SESSION, return_total_entries: Literal[True] = True, ) -> tuple[Select, int]: ... @@ -142,10 +142,10 @@ def paginated_select( def paginated_select( *, statement: Select, - filters: Sequence[BaseParam] | None = None, - order_by: BaseParam | None = None, - offset: BaseParam | None = None, - limit: BaseParam | None = None, + filters: Sequence[OrmClause] | None = None, + order_by: OrmClause | None = None, + offset: OrmClause | None = None, + limit: OrmClause | None = None, session: Session = NEW_SESSION, return_total_entries: Literal[False], ) -> tuple[Select, None]: ... @@ -155,10 +155,10 @@ def paginated_select( def paginated_select( *, statement: Select, - filters: Sequence[BaseParam] | None = None, - order_by: BaseParam | None = None, - offset: BaseParam | None = None, - limit: BaseParam | None = None, + filters: Sequence[OrmClause] | None = None, + order_by: OrmClause | None = None, + offset: OrmClause | None = None, + limit: OrmClause | None = None, session: Session = NEW_SESSION, return_total_entries: bool = True, ) -> tuple[Select, int | None]: diff --git a/airflow/api_fastapi/common/parameters.py b/airflow/api_fastapi/common/parameters.py index f69d64bd5d031..d7ce038d10411 100644 --- a/airflow/api_fastapi/common/parameters.py +++ b/airflow/api_fastapi/common/parameters.py @@ -40,6 +40,7 @@ from sqlalchemy import Column, and_, case, or_ from sqlalchemy.inspection import inspect +from airflow.api_fastapi.core_api.base import OrmClause from airflow.models import Base from airflow.models.asset import ( AssetAliasModel, @@ -65,18 +66,14 @@ T = TypeVar("T") -class BaseParam(Generic[T], ABC): - """Base class for filters.""" +class BaseParam(OrmClause[T], ABC): + """Base class for path or query parameters with ORM transformation.""" def __init__(self, value: T | None = None, skip_none: bool = True) -> None: - self.value = value + super().__init__(value) self.attribute: ColumnElement | None = None self.skip_none = skip_none - @abstractmethod - def to_orm(self, select: Select) -> Select: - pass - def set_value(self, value: T | None) -> Self: self.value = value return self diff --git a/airflow/api_fastapi/core_api/base.py b/airflow/api_fastapi/core_api/base.py index d88ec1757eb60..887f528f197ef 100644 --- a/airflow/api_fastapi/core_api/base.py +++ b/airflow/api_fastapi/core_api/base.py @@ -16,8 +16,16 @@ # under the License. from __future__ import annotations +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Generic, TypeVar + from pydantic import BaseModel as PydanticBaseModel, ConfigDict +if TYPE_CHECKING: + from sqlalchemy.sql import Select + +T = TypeVar("T") + class BaseModel(PydanticBaseModel): """ @@ -39,3 +47,18 @@ class StrictBaseModel(BaseModel): """ model_config = ConfigDict(from_attributes=True, populate_by_name=True, extra="forbid") + + +class OrmClause(Generic[T], ABC): + """ + Base class for filtering clauses with paginated_select. + + The subclasses should implement the `to_orm` method and set the `value` attribute. + """ + + def __init__(self, value: T | None = None): + self.value = value + + @abstractmethod + def to_orm(self, select: Select) -> Select: + pass diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 7555e3e478856..ae00f8b688179 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -3058,6 +3058,8 @@ paths: summary: Get Dags description: Get all DAGs. operationId: get_dags + security: + - OAuth2PasswordBearer: [] parameters: - name: limit in: query @@ -3223,6 +3225,8 @@ paths: summary: Patch Dags description: Patch multiple DAGs. operationId: patch_dags + security: + - OAuth2PasswordBearer: [] parameters: - name: update_mask in: query @@ -3358,6 +3362,8 @@ paths: summary: Get Dag description: Get basic information about a DAG. operationId: get_dag + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -3408,6 +3414,8 @@ paths: summary: Patch Dag description: Patch the specific DAG. operationId: patch_dag + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -3474,6 +3482,8 @@ paths: summary: Delete Dag description: Delete the specific DAG. operationId: delete_dag + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -3524,6 +3534,8 @@ paths: summary: Get Dag Details description: Get details of DAG. operationId: get_dag_details + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path diff --git a/airflow/api_fastapi/core_api/routes/public/dags.py b/airflow/api_fastapi/core_api/routes/public/dags.py index 01c44ed29a0c1..c5d95f7821053 100644 --- a/airflow/api_fastapi/core_api/routes/public/dags.py +++ b/airflow/api_fastapi/core_api/routes/public/dags.py @@ -57,6 +57,11 @@ DAGResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc +from airflow.api_fastapi.core_api.security import ( + EditableDagsFilterDep, + ReadableDagsFilterDep, + requires_access_dag, +) from airflow.api_fastapi.logging.decorators import action_logging from airflow.exceptions import AirflowException, DagNotFound from airflow.models import DAG, DagModel @@ -65,7 +70,7 @@ dags_router = AirflowRouter(tags=["DAG"], prefix="/dags") -@dags_router.get("") +@dags_router.get("", dependencies=[Depends(requires_access_dag(method="GET"))]) def get_dags( limit: QueryLimit, offset: QueryOffset, @@ -105,6 +110,7 @@ def get_dags( ).dynamic_depends() ), ], + readable_dags_filter: ReadableDagsFilterDep, session: SessionDep, ) -> DAGCollectionResponse: """Get all DAGs.""" @@ -132,6 +138,7 @@ def get_dags( tags, owners, last_dag_run_state, + readable_dags_filter, ], order_by=order_by, offset=offset, @@ -156,6 +163,7 @@ def get_dags( status.HTTP_422_UNPROCESSABLE_ENTITY, ] ), + dependencies=[Depends(requires_access_dag(method="GET"))], ) def get_dag(dag_id: str, session: SessionDep, request: Request) -> DAGResponse: """Get basic information about a DAG.""" @@ -182,6 +190,7 @@ def get_dag(dag_id: str, session: SessionDep, request: Request) -> DAGResponse: status.HTTP_404_NOT_FOUND, ] ), + dependencies=[Depends(requires_access_dag(method="GET"))], ) def get_dag_details(dag_id: str, session: SessionDep, request: Request) -> DAGDetailsResponse: """Get details of DAG.""" @@ -208,7 +217,7 @@ def get_dag_details(dag_id: str, session: SessionDep, request: Request) -> DAGDe status.HTTP_404_NOT_FOUND, ] ), - dependencies=[Depends(action_logging())], + dependencies=[Depends(requires_access_dag(method="PUT")), Depends(action_logging())], ) def patch_dag( dag_id: str, @@ -251,7 +260,7 @@ def patch_dag( status.HTTP_404_NOT_FOUND, ] ), - dependencies=[Depends(action_logging())], + dependencies=[Depends(requires_access_dag(method="PUT")), Depends(action_logging())], ) def patch_dags( patch_body: DAGPatchBody, @@ -263,6 +272,7 @@ def patch_dags( only_active: QueryOnlyActiveFilter, paused: QueryPausedFilter, last_dag_run_state: QueryLastDagRunStateFilter, + editable_dags_filter: EditableDagsFilterDep, session: SessionDep, update_mask: list[str] | None = Query(None), ) -> DAGCollectionResponse: @@ -283,7 +293,7 @@ def patch_dags( dags_select, total_entries = paginated_select( statement=generate_dag_with_latest_run_query(), - filters=[only_active, paused, dag_id_pattern, tags, owners, last_dag_run_state], + filters=[only_active, paused, dag_id_pattern, tags, owners, last_dag_run_state, editable_dags_filter], order_by=None, offset=offset, limit=limit, @@ -313,7 +323,7 @@ def patch_dags( status.HTTP_422_UNPROCESSABLE_ENTITY, ] ), - dependencies=[Depends(action_logging())], + dependencies=[Depends(requires_access_dag(method="DELETE")), Depends(action_logging())], ) def delete_dag( dag_id: str, diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py index 2e5745ffed272..7ef10be39e3e2 100644 --- a/airflow/api_fastapi/core_api/security.py +++ b/airflow/api_fastapi/core_api/security.py @@ -36,11 +36,15 @@ PoolDetails, VariableDetails, ) +from airflow.api_fastapi.core_api.base import OrmClause from airflow.configuration import conf +from airflow.models.dag import DagModel from airflow.utils.jwt_signer import JWTSigner, get_signing_key if TYPE_CHECKING: - from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod + from sqlalchemy.sql import Select + + from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @@ -63,6 +67,9 @@ def get_user(token_str: Annotated[str, Depends(oauth2_scheme)]) -> BaseUser: raise HTTPException(status.HTTP_403_FORBIDDEN, "Forbidden") +GetUserDep = Annotated[BaseUser, Depends(get_user)] + + async def get_user_with_exception_handling(request: Request) -> BaseUser | None: # Currently the UI does not support JWT authentication, this method defines a fallback if no token is provided by the UI. # We can remove this method when issue https://github.com/apache/airflow/issues/44884 is done. @@ -80,12 +87,14 @@ async def get_user_with_exception_handling(request: Request) -> BaseUser | None: return get_user(token_str) -def requires_access_dag(method: ResourceMethod, access_entity: DagAccessEntity | None = None) -> Callable: +def requires_access_dag( + method: ResourceMethod, access_entity: DagAccessEntity | None = None +) -> Callable[[Request, BaseUser], None]: def inner( request: Request, - user: Annotated[BaseUser, Depends(get_user)], + user: GetUserDep, ) -> None: - dag_id = request.path_params.get("dag_id") or request.query_params.get("dag_id") + dag_id: str | None = request.path_params.get("dag_id") _requires_access( is_authorized_callback=lambda: get_auth_manager().is_authorized_dag( @@ -96,10 +105,40 @@ def inner( return inner +class PermittedDagFilter(OrmClause[set[str]]): + """A parameter that filters the permitted dags for the user.""" + + def to_orm(self, select: Select) -> Select: + return select.where(DagModel.dag_id.in_(self.value)) + + +def permitted_dag_filter_factory(method: ResourceMethod) -> Callable[[Request, BaseUser], PermittedDagFilter]: + """ + Create a callable for Depends in FastAPI that returns a filter of the permitted dags for the user. + + :param method: whether filter readable or writable. + :return: The callable that can be used as Depends in FastAPI. + """ + + def depends_permitted_dags_filter( + request: Request, + user: GetUserDep, + ) -> PermittedDagFilter: + auth_manager: BaseAuthManager = request.app.state.auth_manager + permitted_dags: set[str] = auth_manager.get_permitted_dag_ids(user=user, method=method) + return PermittedDagFilter(permitted_dags) + + return depends_permitted_dags_filter + + +EditableDagsFilterDep = Annotated[PermittedDagFilter, Depends(permitted_dag_filter_factory("PUT"))] +ReadableDagsFilterDep = Annotated[PermittedDagFilter, Depends(permitted_dag_filter_factory("GET"))] + + def requires_access_pool(method: ResourceMethod) -> Callable[[Request, BaseUser], None]: def inner( request: Request, - user: Annotated[BaseUser, Depends(get_user)], + user: GetUserDep, ) -> None: pool_name = request.path_params.get("pool_name") @@ -115,7 +154,7 @@ def inner( def requires_access_connection(method: ResourceMethod) -> Callable[[Request, BaseUser], None]: def inner( request: Request, - user: Annotated[BaseUser, Depends(get_user)], + user: GetUserDep, ) -> None: connection_id = request.path_params.get("connection_id") diff --git a/docker_tests/test_docker_compose_quick_start.py b/docker_tests/test_docker_compose_quick_start.py index c91dbb36dba6d..ccf4ef18fcd2e 100644 --- a/docker_tests/test_docker_compose_quick_start.py +++ b/docker_tests/test_docker_compose_quick_start.py @@ -18,10 +18,12 @@ import json import os +import re import shlex from pprint import pprint from shutil import copyfile from time import sleep +from urllib.parse import parse_qs, urlparse import pytest import requests @@ -34,18 +36,67 @@ # isort:on (needed to workaround isort bug) +DOCKER_COMPOSE_HOST_PORT = os.environ.get("HOST_PORT", "localhost:8080") AIRFLOW_WWW_USER_USERNAME = os.environ.get("_AIRFLOW_WWW_USER_USERNAME", "airflow") AIRFLOW_WWW_USER_PASSWORD = os.environ.get("_AIRFLOW_WWW_USER_PASSWORD", "airflow") DAG_ID = "example_bash_operator" DAG_RUN_ID = "test_dag_run_id" -def api_request(method: str, path: str, base_url: str = "http://localhost:8080/public", **kwargs) -> dict: +def get_jwt_token() -> str: + """Get the JWT token. + + Note: API server is still using FAB Auth Manager. + + Steps: + 1. Get the login page to get the csrf token + - The csrf token is in the hidden input field with id "csrf_token" + 2. Login with the username and password + - Must use the same session to keep the csrf token session + 3. Extract the JWT token from the redirect url + - Expected to have a connection error + - The redirect url should have the JWT token as a query parameter + + :return: The JWT token + """ + # get csrf token from login page + session = requests.Session() + get_login_form_response = session.get(f"http://{DOCKER_COMPOSE_HOST_PORT}/auth/login") + csrf_token = re.search( + r'', + get_login_form_response.text, + ) + assert csrf_token, "Failed to get csrf token from login page" + csrf_token_str = csrf_token.group(1) + assert csrf_token_str, "Failed to get csrf token from login page" + # login with form data + login_response = session.post( + f"http://{DOCKER_COMPOSE_HOST_PORT}/auth/login", + data={ + "username": AIRFLOW_WWW_USER_USERNAME, + "password": AIRFLOW_WWW_USER_PASSWORD, + "csrf_token": csrf_token_str, + }, + ) + redirect_url = login_response.url + # ensure redirect_url is a string + redirect_url_str = str(redirect_url) if redirect_url is not None else "" + assert "/?token" in redirect_url_str, f"Login failed with redirect url {redirect_url_str}" + parsed_url = urlparse(redirect_url_str) + query_params = parse_qs(str(parsed_url.query)) + jwt_token_list = query_params.get("token") + jwt_token = jwt_token_list[0] if jwt_token_list else None + assert jwt_token, f"Failed to get JWT token from redirect url {redirect_url_str}" + return jwt_token + + +def api_request( + method: str, path: str, base_url: str = f"http://{DOCKER_COMPOSE_HOST_PORT}/public", **kwargs +) -> dict: response = requests.request( method=method, url=f"{base_url}/{path}", - auth=(AIRFLOW_WWW_USER_USERNAME, AIRFLOW_WWW_USER_PASSWORD), - headers={"Content-Type": "application/json"}, + headers={"Authorization": f"Bearer {get_jwt_token()}", "Content-Type": "application/json"}, **kwargs, ) response.raise_for_status() diff --git a/kubernetes_tests/test_base.py b/kubernetes_tests/test_base.py index 31e1924c18ad8..31248b02ac420 100644 --- a/kubernetes_tests/test_base.py +++ b/kubernetes_tests/test_base.py @@ -25,6 +25,7 @@ from datetime import datetime, timezone from pathlib import Path from subprocess import check_call, check_output +from urllib.parse import parse_qs, urlparse import pytest import requests @@ -59,9 +60,12 @@ class BaseK8STest: @pytest.fixture(autouse=True) def base_tests_setup(self, request): - self.set_api_server_base_url_config() - self.rollout_restart_deployment("airflow-api-server") - self.ensure_deployment_health("airflow-api-server") + if self.set_api_server_base_url_config(): + # only restart the deployment if the configmap was updated + # speed up the test and make the airflow-api-server deployment more stable + self.rollout_restart_deployment("airflow-api-server") + self.ensure_deployment_health("airflow-api-server") + # Replacement for unittests.TestCase.id() self.test_id = f"{request.node.cls.__name__}_{request.node.name}" self.session = self._get_session_with_retries() @@ -126,17 +130,92 @@ def _delete_airflow_pod(name=""): if names: check_call(["kubectl", "delete", "pod", names[0]]) + @staticmethod + def _get_jwt_token(username: str, password: str) -> str: + """Get the JWT token for the given username and password. + + Note: API server is still using FAB Auth Manager. + + Steps: + 1. Get the login page to get the csrf token + - The csrf token is in the hidden input field with id "csrf_token" + 2. Login with the username and password + - Must use the same session to keep the csrf token session + 3. Extract the JWT token from the redirect url + - Expected to have a connection error + - The redirect url should have the JWT token as a query parameter + + :param session: The session to use for the request + :param username: The username to use for the login + :param password: The password to use for the login + :return: The JWT token + """ + # get csrf token from login page + retry = Retry(total=5, backoff_factor=10) + session = requests.Session() + session.mount("http://", HTTPAdapter(max_retries=retry)) + session.mount("https://", HTTPAdapter(max_retries=retry)) + get_login_form_response = session.get(f"http://{KUBERNETES_HOST_PORT}/auth/login") + csrf_token = re.search( + r'', + get_login_form_response.text, + ) + assert csrf_token, "Failed to get csrf token from login page" + csrf_token_str = csrf_token.group(1) + assert csrf_token_str, "Failed to get csrf token from login page" + # login with form data + login_response = session.post( + f"http://{KUBERNETES_HOST_PORT}/auth/login", + data={"username": username, "password": password, "csrf_token": csrf_token_str}, + ) + redirect_url = login_response.url + # ensure redirect_url is a string + redirect_url_str = str(redirect_url) if redirect_url is not None else "" + assert "/?token" in redirect_url_str, f"Login failed with redirect url {redirect_url_str}" + parsed_url = urlparse(redirect_url_str) + query_params = parse_qs(str(parsed_url.query)) + jwt_token_list = query_params.get("token") + jwt_token = jwt_token_list[0] if jwt_token_list else None + assert jwt_token, f"Failed to get JWT token from redirect url {redirect_url_str}" + return jwt_token + def _get_session_with_retries(self): + class JWTRefreshAdapter(HTTPAdapter): + def __init__(self, base_instance, **kwargs): + self.base_instance = base_instance + super().__init__(**kwargs) + + def send(self, request, **kwargs): + response = super().send(request, **kwargs) + if response.status_code in (401, 403): + # Refresh token and update the Authorization header with retry logic. + attempts = 0 + jwt_token = None + while attempts < 5: + try: + jwt_token = self.base_instance._get_jwt_token("admin", "admin") + break + except Exception: + attempts += 1 + time.sleep(1) + if jwt_token is None: + raise Exception("Failed to refresh JWT token after 5 attempts") + request.headers["Authorization"] = f"Bearer {jwt_token}" + response = super().send(request, **kwargs) + return response + + jwt_token = self._get_jwt_token("admin", "admin") session = requests.Session() - session.auth = ("admin", "admin") + session.headers.update({"Authorization": f"Bearer {jwt_token}"}) retries = Retry( - total=3, + total=5, backoff_factor=10, status_forcelist=[404], allowed_methods=Retry.DEFAULT_ALLOWED_METHODS | frozenset(["PATCH", "POST"]), ) - session.mount("http://", HTTPAdapter(max_retries=retries)) - session.mount("https://", HTTPAdapter(max_retries=retries)) + adapter = JWTRefreshAdapter(self, max_retries=retries) + session.mount("http://", adapter) + session.mount("https://", adapter) return session def _ensure_airflow_api_server_is_healthy(self): @@ -236,8 +315,11 @@ def _parse_airflow_cfg_dict_as_escaped_toml(self, airflow_cfg_dict: dict) -> str # escape newlines and double quotes return airflow_cfg_str.replace("\n", "\\n").replace('"', '\\"') - def set_api_server_base_url_config(self): - """Set [api/base_url] with `f"http://{KUBERNETES_HOST_PORT}"` as env in k8s configmap.""" + def set_api_server_base_url_config(self) -> bool: + """Set [api/base_url] with `f"http://{KUBERNETES_HOST_PORT}"` as env in k8s configmap. + + :return: True if the configmap was updated successfully, False otherwise + """ configmap_name = "airflow-config" configmap_key = "airflow.cfg" original_configmap_json_str = check_output( @@ -250,7 +332,7 @@ def set_api_server_base_url_config(self): airflow_cfg_dict = self._parse_airflow_cfg_as_dict(original_airflow_cfg) airflow_cfg_dict["api"]["base_url"] = f"http://{KUBERNETES_HOST_PORT}" # update the configmap with the new airflow.cfg - check_call( + patch_configmap_result = check_output( [ "kubectl", "patch", @@ -263,7 +345,10 @@ def set_api_server_base_url_config(self): "-p", f'{{"data": {{"{configmap_key}": "{self._parse_airflow_cfg_dict_as_escaped_toml(airflow_cfg_dict)}"}}}}', ] - ) + ).decode() + if "(no change)" in patch_configmap_result: + return False + return True def ensure_dag_expected_state(self, host, logical_date, dag_id, expected_final_state, timeout): tries = 0 diff --git a/kubernetes_tests/test_kubernetes_executor.py b/kubernetes_tests/test_kubernetes_executor.py index 63d1389b17103..8a7596f3cda70 100644 --- a/kubernetes_tests/test_kubernetes_executor.py +++ b/kubernetes_tests/test_kubernetes_executor.py @@ -50,7 +50,7 @@ def test_integration_run_dag(self): timeout=300, ) - @pytest.mark.execution_timeout(300) + @pytest.mark.execution_timeout(500) def test_integration_run_dag_with_scheduler_failure(self): dag_id = "example_kubernetes_executor" diff --git a/kubernetes_tests/test_other_executors.py b/kubernetes_tests/test_other_executors.py index f8203b069b404..327e252825a37 100644 --- a/kubernetes_tests/test_other_executors.py +++ b/kubernetes_tests/test_other_executors.py @@ -16,8 +16,6 @@ # under the License. from __future__ import annotations -import time - import pytest from kubernetes_tests.test_base import ( @@ -68,8 +66,7 @@ def test_integration_run_dag_with_scheduler_failure(self): dag_run_id, logical_date = self.start_job_in_kubernetes(dag_id, self.host) self._delete_airflow_pod("scheduler") - - time.sleep(10) # give time for pod to restart + self.ensure_deployment_health("airflow-scheduler") # Wait some time for the operator to complete self.monitor_task( diff --git a/tests/api_fastapi/core_api/routes/public/test_dags.py b/tests/api_fastapi/core_api/routes/public/test_dags.py index 21f1bd0235c52..93c7c20ea3afa 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dags.py +++ b/tests/api_fastapi/core_api/routes/public/test_dags.py @@ -235,13 +235,31 @@ class TestGetDags(TestDagEndpoint): ) def test_get_dags(self, test_client, query_params, expected_total_entries, expected_ids): response = test_client.get("/public/dags", params=query_params) - assert response.status_code == 200 body = response.json() assert body["total_entries"] == expected_total_entries assert [dag["dag_id"] for dag in body["dags"]] == expected_ids + @mock.patch("airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_permitted_dag_ids") + def test_get_dags_should_call_permitted_dag_ids(self, mock_get_permitted_dag_ids, test_client): + mock_get_permitted_dag_ids.return_value = {DAG1_ID, DAG2_ID} + response = test_client.get("/public/dags") + mock_get_permitted_dag_ids.assert_called_once_with(user=mock.ANY, method="GET") + assert response.status_code == 200 + body = response.json() + + assert body["total_entries"] == 2 + assert [dag["dag_id"] for dag in body["dags"]] == [DAG1_ID, DAG2_ID] + + def test_get_dags_should_response_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get("/public/dags") + assert response.status_code == 401 + + def test_get_dags_should_response_403(self, unauthorized_test_client): + response = unauthorized_test_client.get("/public/dags") + assert response.status_code == 403 + class TestPatchDag(TestDagEndpoint): """Unit tests for Patch DAG.""" @@ -268,6 +286,14 @@ def test_patch_dag( assert body["is_paused"] == expected_is_paused check_last_log(session, dag_id=dag_id, event="patch_dag", logical_date=None) + def test_patch_dag_should_response_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.patch(f"/public/dags/{DAG1_ID}", json={"is_paused": True}) + assert response.status_code == 401 + + def test_patch_dag_should_response_403(self, unauthorized_test_client): + response = unauthorized_test_client.patch(f"/public/dags/{DAG1_ID}", json={"is_paused": True}) + assert response.status_code == 403 + class TestPatchDags(TestDagEndpoint): """Unit tests for Patch DAGs.""" @@ -333,6 +359,26 @@ def test_patch_dags( assert paused_dag_ids == expected_paused_ids check_last_log(session, dag_id=DAG1_ID, event="patch_dag", logical_date=None) + @mock.patch("airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_permitted_dag_ids") + def test_patch_dags_should_call_permitted_dag_ids(self, mock_get_permitted_dag_ids, test_client): + mock_get_permitted_dag_ids.return_value = {DAG1_ID, DAG2_ID} + response = test_client.patch( + "/public/dags", json={"is_paused": False}, params={"only_active": False, "dag_id_pattern": "~"} + ) + mock_get_permitted_dag_ids.assert_called_once_with(user=mock.ANY, method="PUT") + assert response.status_code == 200 + body = response.json() + + assert [dag["dag_id"] for dag in body["dags"]] == [DAG1_ID, DAG2_ID] + + def test_patch_dags_should_response_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.patch("/public/dags", json={"is_paused": True}) + assert response.status_code == 401 + + def test_patch_dags_should_response_403(self, unauthorized_test_client): + response = unauthorized_test_client.patch("/public/dags", json={"is_paused": True}) + assert response.status_code == 403 + class TestDagDetails(TestDagEndpoint): """Unit tests for DAG Details.""" @@ -414,6 +460,14 @@ def test_dag_details( } assert res_json == expected + def test_dag_details_should_response_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get(f"/public/dags/{DAG1_ID}/details") + assert response.status_code == 401 + + def test_dag_details_should_response_403(self, unauthorized_test_client): + response = unauthorized_test_client.get(f"/public/dags/{DAG1_ID}/details") + assert response.status_code == 403 + class TestGetDag(TestDagEndpoint): """Unit tests for Get DAG.""" @@ -462,6 +516,14 @@ def test_get_dag(self, test_client, query_params, dag_id, expected_status_code, } assert res_json == expected + def test_get_dag_should_response_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get(f"/public/dags/{DAG1_ID}") + assert response.status_code == 401 + + def test_get_dag_should_response_403(self, unauthorized_test_client): + response = unauthorized_test_client.get(f"/public/dags/{DAG1_ID}") + assert response.status_code == 403 + class TestDeleteDAG(TestDagEndpoint): """Unit tests for Delete DAG.""" @@ -521,5 +583,14 @@ def test_delete_dag( details_response = test_client.get(f"{API_PREFIX}/{dag_id}/details") assert details_response.status_code == status_code_details + if details_response.status_code == 204: check_last_log(session, dag_id=dag_id, event="delete_dag", logical_date=None) + + def test_delete_dag_should_response_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.delete(f"{API_PREFIX}/{DAG1_ID}") + assert response.status_code == 401 + + def test_delete_dag_should_response_403(self, unauthorized_test_client): + response = unauthorized_test_client.delete(f"{API_PREFIX}/{DAG1_ID}") + assert response.status_code == 403 diff --git a/tests/api_fastapi/core_api/test_security.py b/tests/api_fastapi/core_api/test_security.py index f777cda5e83f0..193cc67798f91 100644 --- a/tests/api_fastapi/core_api/test_security.py +++ b/tests/api_fastapi/core_api/test_security.py @@ -88,11 +88,10 @@ def test_requires_access_dag_authorized(self, mock_get_auth_manager): auth_manager = Mock() auth_manager.is_authorized_dag.return_value = True mock_get_auth_manager.return_value = auth_manager + fastapi_request = Mock() + fastapi_request.path_params.return_value = {} - mock_request = Mock() - mock_request.path_params.return_value = {"dag_id": "test"} - - requires_access_dag("GET", DagAccessEntity.CODE)(mock_request, Mock()) + requires_access_dag("GET", DagAccessEntity.CODE)(fastapi_request, Mock()) auth_manager.is_authorized_dag.assert_called_once() @@ -101,11 +100,13 @@ def test_requires_access_dag_unauthorized(self, mock_get_auth_manager): auth_manager = Mock() auth_manager.is_authorized_dag.return_value = False mock_get_auth_manager.return_value = auth_manager + fastapi_request = Mock() + fastapi_request.path_params.return_value = {} mock_request = Mock() mock_request.path_params.return_value = {} with pytest.raises(HTTPException, match="Forbidden"): - requires_access_dag("GET", DagAccessEntity.CODE)(mock_request, Mock()) + requires_access_dag("GET", DagAccessEntity.CODE)(fastapi_request, Mock()) auth_manager.is_authorized_dag.assert_called_once()