diff --git a/airflow/api_fastapi/common/db/common.py b/airflow/api_fastapi/common/db/common.py
index 84297724ce025..363a7fbdc4b0f 100644
--- a/airflow/api_fastapi/common/db/common.py
+++ b/airflow/api_fastapi/common/db/common.py
@@ -35,7 +35,7 @@
if TYPE_CHECKING:
from sqlalchemy.sql import Select
- from airflow.api_fastapi.common.parameters import BaseParam
+ from airflow.api_fastapi.core_api.base import OrmClause
def _get_session() -> Session:
@@ -47,7 +47,7 @@ def _get_session() -> Session:
def apply_filters_to_select(
- *, statement: Select, filters: Sequence[BaseParam | None] | None = None
+ *, statement: Select, filters: Sequence[OrmClause | None] | None = None
) -> Select:
if filters is None:
return statement
@@ -71,10 +71,10 @@ async def _get_async_session() -> AsyncSession:
async def paginated_select_async(
*,
statement: Select,
- filters: Sequence[BaseParam] | None = None,
- order_by: BaseParam | None = None,
- offset: BaseParam | None = None,
- limit: BaseParam | None = None,
+ filters: Sequence[OrmClause] | None = None,
+ order_by: OrmClause | None = None,
+ offset: OrmClause | None = None,
+ limit: OrmClause | None = None,
session: AsyncSession,
return_total_entries: Literal[True] = True,
) -> tuple[Select, int]: ...
@@ -84,10 +84,10 @@ async def paginated_select_async(
async def paginated_select_async(
*,
statement: Select,
- filters: Sequence[BaseParam] | None = None,
- order_by: BaseParam | None = None,
- offset: BaseParam | None = None,
- limit: BaseParam | None = None,
+ filters: Sequence[OrmClause] | None = None,
+ order_by: OrmClause | None = None,
+ offset: OrmClause | None = None,
+ limit: OrmClause | None = None,
session: AsyncSession,
return_total_entries: Literal[False],
) -> tuple[Select, None]: ...
@@ -96,10 +96,10 @@ async def paginated_select_async(
async def paginated_select_async(
*,
statement: Select,
- filters: Sequence[BaseParam | None] | None = None,
- order_by: BaseParam | None = None,
- offset: BaseParam | None = None,
- limit: BaseParam | None = None,
+ filters: Sequence[OrmClause | None] | None = None,
+ order_by: OrmClause | None = None,
+ offset: OrmClause | None = None,
+ limit: OrmClause | None = None,
session: AsyncSession,
return_total_entries: bool = True,
) -> tuple[Select, int | None]:
@@ -129,10 +129,10 @@ async def paginated_select_async(
def paginated_select(
*,
statement: Select,
- filters: Sequence[BaseParam] | None = None,
- order_by: BaseParam | None = None,
- offset: BaseParam | None = None,
- limit: BaseParam | None = None,
+ filters: Sequence[OrmClause] | None = None,
+ order_by: OrmClause | None = None,
+ offset: OrmClause | None = None,
+ limit: OrmClause | None = None,
session: Session = NEW_SESSION,
return_total_entries: Literal[True] = True,
) -> tuple[Select, int]: ...
@@ -142,10 +142,10 @@ def paginated_select(
def paginated_select(
*,
statement: Select,
- filters: Sequence[BaseParam] | None = None,
- order_by: BaseParam | None = None,
- offset: BaseParam | None = None,
- limit: BaseParam | None = None,
+ filters: Sequence[OrmClause] | None = None,
+ order_by: OrmClause | None = None,
+ offset: OrmClause | None = None,
+ limit: OrmClause | None = None,
session: Session = NEW_SESSION,
return_total_entries: Literal[False],
) -> tuple[Select, None]: ...
@@ -155,10 +155,10 @@ def paginated_select(
def paginated_select(
*,
statement: Select,
- filters: Sequence[BaseParam] | None = None,
- order_by: BaseParam | None = None,
- offset: BaseParam | None = None,
- limit: BaseParam | None = None,
+ filters: Sequence[OrmClause] | None = None,
+ order_by: OrmClause | None = None,
+ offset: OrmClause | None = None,
+ limit: OrmClause | None = None,
session: Session = NEW_SESSION,
return_total_entries: bool = True,
) -> tuple[Select, int | None]:
diff --git a/airflow/api_fastapi/common/parameters.py b/airflow/api_fastapi/common/parameters.py
index f69d64bd5d031..d7ce038d10411 100644
--- a/airflow/api_fastapi/common/parameters.py
+++ b/airflow/api_fastapi/common/parameters.py
@@ -40,6 +40,7 @@
from sqlalchemy import Column, and_, case, or_
from sqlalchemy.inspection import inspect
+from airflow.api_fastapi.core_api.base import OrmClause
from airflow.models import Base
from airflow.models.asset import (
AssetAliasModel,
@@ -65,18 +66,14 @@
T = TypeVar("T")
-class BaseParam(Generic[T], ABC):
- """Base class for filters."""
+class BaseParam(OrmClause[T], ABC):
+ """Base class for path or query parameters with ORM transformation."""
def __init__(self, value: T | None = None, skip_none: bool = True) -> None:
- self.value = value
+ super().__init__(value)
self.attribute: ColumnElement | None = None
self.skip_none = skip_none
- @abstractmethod
- def to_orm(self, select: Select) -> Select:
- pass
-
def set_value(self, value: T | None) -> Self:
self.value = value
return self
diff --git a/airflow/api_fastapi/core_api/base.py b/airflow/api_fastapi/core_api/base.py
index d88ec1757eb60..887f528f197ef 100644
--- a/airflow/api_fastapi/core_api/base.py
+++ b/airflow/api_fastapi/core_api/base.py
@@ -16,8 +16,16 @@
# under the License.
from __future__ import annotations
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Generic, TypeVar
+
from pydantic import BaseModel as PydanticBaseModel, ConfigDict
+if TYPE_CHECKING:
+ from sqlalchemy.sql import Select
+
+T = TypeVar("T")
+
class BaseModel(PydanticBaseModel):
"""
@@ -39,3 +47,18 @@ class StrictBaseModel(BaseModel):
"""
model_config = ConfigDict(from_attributes=True, populate_by_name=True, extra="forbid")
+
+
+class OrmClause(Generic[T], ABC):
+ """
+ Base class for filtering clauses with paginated_select.
+
+ The subclasses should implement the `to_orm` method and set the `value` attribute.
+ """
+
+ def __init__(self, value: T | None = None):
+ self.value = value
+
+ @abstractmethod
+ def to_orm(self, select: Select) -> Select:
+ pass
diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index 7555e3e478856..ae00f8b688179 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -3058,6 +3058,8 @@ paths:
summary: Get Dags
description: Get all DAGs.
operationId: get_dags
+ security:
+ - OAuth2PasswordBearer: []
parameters:
- name: limit
in: query
@@ -3223,6 +3225,8 @@ paths:
summary: Patch Dags
description: Patch multiple DAGs.
operationId: patch_dags
+ security:
+ - OAuth2PasswordBearer: []
parameters:
- name: update_mask
in: query
@@ -3358,6 +3362,8 @@ paths:
summary: Get Dag
description: Get basic information about a DAG.
operationId: get_dag
+ security:
+ - OAuth2PasswordBearer: []
parameters:
- name: dag_id
in: path
@@ -3408,6 +3414,8 @@ paths:
summary: Patch Dag
description: Patch the specific DAG.
operationId: patch_dag
+ security:
+ - OAuth2PasswordBearer: []
parameters:
- name: dag_id
in: path
@@ -3474,6 +3482,8 @@ paths:
summary: Delete Dag
description: Delete the specific DAG.
operationId: delete_dag
+ security:
+ - OAuth2PasswordBearer: []
parameters:
- name: dag_id
in: path
@@ -3524,6 +3534,8 @@ paths:
summary: Get Dag Details
description: Get details of DAG.
operationId: get_dag_details
+ security:
+ - OAuth2PasswordBearer: []
parameters:
- name: dag_id
in: path
diff --git a/airflow/api_fastapi/core_api/routes/public/dags.py b/airflow/api_fastapi/core_api/routes/public/dags.py
index 01c44ed29a0c1..c5d95f7821053 100644
--- a/airflow/api_fastapi/core_api/routes/public/dags.py
+++ b/airflow/api_fastapi/core_api/routes/public/dags.py
@@ -57,6 +57,11 @@
DAGResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
+from airflow.api_fastapi.core_api.security import (
+ EditableDagsFilterDep,
+ ReadableDagsFilterDep,
+ requires_access_dag,
+)
from airflow.api_fastapi.logging.decorators import action_logging
from airflow.exceptions import AirflowException, DagNotFound
from airflow.models import DAG, DagModel
@@ -65,7 +70,7 @@
dags_router = AirflowRouter(tags=["DAG"], prefix="/dags")
-@dags_router.get("")
+@dags_router.get("", dependencies=[Depends(requires_access_dag(method="GET"))])
def get_dags(
limit: QueryLimit,
offset: QueryOffset,
@@ -105,6 +110,7 @@ def get_dags(
).dynamic_depends()
),
],
+ readable_dags_filter: ReadableDagsFilterDep,
session: SessionDep,
) -> DAGCollectionResponse:
"""Get all DAGs."""
@@ -132,6 +138,7 @@ def get_dags(
tags,
owners,
last_dag_run_state,
+ readable_dags_filter,
],
order_by=order_by,
offset=offset,
@@ -156,6 +163,7 @@ def get_dags(
status.HTTP_422_UNPROCESSABLE_ENTITY,
]
),
+ dependencies=[Depends(requires_access_dag(method="GET"))],
)
def get_dag(dag_id: str, session: SessionDep, request: Request) -> DAGResponse:
"""Get basic information about a DAG."""
@@ -182,6 +190,7 @@ def get_dag(dag_id: str, session: SessionDep, request: Request) -> DAGResponse:
status.HTTP_404_NOT_FOUND,
]
),
+ dependencies=[Depends(requires_access_dag(method="GET"))],
)
def get_dag_details(dag_id: str, session: SessionDep, request: Request) -> DAGDetailsResponse:
"""Get details of DAG."""
@@ -208,7 +217,7 @@ def get_dag_details(dag_id: str, session: SessionDep, request: Request) -> DAGDe
status.HTTP_404_NOT_FOUND,
]
),
- dependencies=[Depends(action_logging())],
+ dependencies=[Depends(requires_access_dag(method="PUT")), Depends(action_logging())],
)
def patch_dag(
dag_id: str,
@@ -251,7 +260,7 @@ def patch_dag(
status.HTTP_404_NOT_FOUND,
]
),
- dependencies=[Depends(action_logging())],
+ dependencies=[Depends(requires_access_dag(method="PUT")), Depends(action_logging())],
)
def patch_dags(
patch_body: DAGPatchBody,
@@ -263,6 +272,7 @@ def patch_dags(
only_active: QueryOnlyActiveFilter,
paused: QueryPausedFilter,
last_dag_run_state: QueryLastDagRunStateFilter,
+ editable_dags_filter: EditableDagsFilterDep,
session: SessionDep,
update_mask: list[str] | None = Query(None),
) -> DAGCollectionResponse:
@@ -283,7 +293,7 @@ def patch_dags(
dags_select, total_entries = paginated_select(
statement=generate_dag_with_latest_run_query(),
- filters=[only_active, paused, dag_id_pattern, tags, owners, last_dag_run_state],
+ filters=[only_active, paused, dag_id_pattern, tags, owners, last_dag_run_state, editable_dags_filter],
order_by=None,
offset=offset,
limit=limit,
@@ -313,7 +323,7 @@ def patch_dags(
status.HTTP_422_UNPROCESSABLE_ENTITY,
]
),
- dependencies=[Depends(action_logging())],
+ dependencies=[Depends(requires_access_dag(method="DELETE")), Depends(action_logging())],
)
def delete_dag(
dag_id: str,
diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py
index 2e5745ffed272..7ef10be39e3e2 100644
--- a/airflow/api_fastapi/core_api/security.py
+++ b/airflow/api_fastapi/core_api/security.py
@@ -36,11 +36,15 @@
PoolDetails,
VariableDetails,
)
+from airflow.api_fastapi.core_api.base import OrmClause
from airflow.configuration import conf
+from airflow.models.dag import DagModel
from airflow.utils.jwt_signer import JWTSigner, get_signing_key
if TYPE_CHECKING:
- from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
+ from sqlalchemy.sql import Select
+
+ from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
@@ -63,6 +67,9 @@ def get_user(token_str: Annotated[str, Depends(oauth2_scheme)]) -> BaseUser:
raise HTTPException(status.HTTP_403_FORBIDDEN, "Forbidden")
+GetUserDep = Annotated[BaseUser, Depends(get_user)]
+
+
async def get_user_with_exception_handling(request: Request) -> BaseUser | None:
# Currently the UI does not support JWT authentication, this method defines a fallback if no token is provided by the UI.
# We can remove this method when issue https://github.com/apache/airflow/issues/44884 is done.
@@ -80,12 +87,14 @@ async def get_user_with_exception_handling(request: Request) -> BaseUser | None:
return get_user(token_str)
-def requires_access_dag(method: ResourceMethod, access_entity: DagAccessEntity | None = None) -> Callable:
+def requires_access_dag(
+ method: ResourceMethod, access_entity: DagAccessEntity | None = None
+) -> Callable[[Request, BaseUser], None]:
def inner(
request: Request,
- user: Annotated[BaseUser, Depends(get_user)],
+ user: GetUserDep,
) -> None:
- dag_id = request.path_params.get("dag_id") or request.query_params.get("dag_id")
+ dag_id: str | None = request.path_params.get("dag_id")
_requires_access(
is_authorized_callback=lambda: get_auth_manager().is_authorized_dag(
@@ -96,10 +105,40 @@ def inner(
return inner
+class PermittedDagFilter(OrmClause[set[str]]):
+ """A parameter that filters the permitted dags for the user."""
+
+ def to_orm(self, select: Select) -> Select:
+ return select.where(DagModel.dag_id.in_(self.value))
+
+
+def permitted_dag_filter_factory(method: ResourceMethod) -> Callable[[Request, BaseUser], PermittedDagFilter]:
+ """
+ Create a callable for Depends in FastAPI that returns a filter of the permitted dags for the user.
+
+ :param method: whether filter readable or writable.
+ :return: The callable that can be used as Depends in FastAPI.
+ """
+
+ def depends_permitted_dags_filter(
+ request: Request,
+ user: GetUserDep,
+ ) -> PermittedDagFilter:
+ auth_manager: BaseAuthManager = request.app.state.auth_manager
+ permitted_dags: set[str] = auth_manager.get_permitted_dag_ids(user=user, method=method)
+ return PermittedDagFilter(permitted_dags)
+
+ return depends_permitted_dags_filter
+
+
+EditableDagsFilterDep = Annotated[PermittedDagFilter, Depends(permitted_dag_filter_factory("PUT"))]
+ReadableDagsFilterDep = Annotated[PermittedDagFilter, Depends(permitted_dag_filter_factory("GET"))]
+
+
def requires_access_pool(method: ResourceMethod) -> Callable[[Request, BaseUser], None]:
def inner(
request: Request,
- user: Annotated[BaseUser, Depends(get_user)],
+ user: GetUserDep,
) -> None:
pool_name = request.path_params.get("pool_name")
@@ -115,7 +154,7 @@ def inner(
def requires_access_connection(method: ResourceMethod) -> Callable[[Request, BaseUser], None]:
def inner(
request: Request,
- user: Annotated[BaseUser, Depends(get_user)],
+ user: GetUserDep,
) -> None:
connection_id = request.path_params.get("connection_id")
diff --git a/docker_tests/test_docker_compose_quick_start.py b/docker_tests/test_docker_compose_quick_start.py
index c91dbb36dba6d..ccf4ef18fcd2e 100644
--- a/docker_tests/test_docker_compose_quick_start.py
+++ b/docker_tests/test_docker_compose_quick_start.py
@@ -18,10 +18,12 @@
import json
import os
+import re
import shlex
from pprint import pprint
from shutil import copyfile
from time import sleep
+from urllib.parse import parse_qs, urlparse
import pytest
import requests
@@ -34,18 +36,67 @@
# isort:on (needed to workaround isort bug)
+DOCKER_COMPOSE_HOST_PORT = os.environ.get("HOST_PORT", "localhost:8080")
AIRFLOW_WWW_USER_USERNAME = os.environ.get("_AIRFLOW_WWW_USER_USERNAME", "airflow")
AIRFLOW_WWW_USER_PASSWORD = os.environ.get("_AIRFLOW_WWW_USER_PASSWORD", "airflow")
DAG_ID = "example_bash_operator"
DAG_RUN_ID = "test_dag_run_id"
-def api_request(method: str, path: str, base_url: str = "http://localhost:8080/public", **kwargs) -> dict:
+def get_jwt_token() -> str:
+ """Get the JWT token.
+
+ Note: API server is still using FAB Auth Manager.
+
+ Steps:
+ 1. Get the login page to get the csrf token
+ - The csrf token is in the hidden input field with id "csrf_token"
+ 2. Login with the username and password
+ - Must use the same session to keep the csrf token session
+ 3. Extract the JWT token from the redirect url
+ - Expected to have a connection error
+ - The redirect url should have the JWT token as a query parameter
+
+ :return: The JWT token
+ """
+ # get csrf token from login page
+ session = requests.Session()
+ get_login_form_response = session.get(f"http://{DOCKER_COMPOSE_HOST_PORT}/auth/login")
+ csrf_token = re.search(
+ r'',
+ get_login_form_response.text,
+ )
+ assert csrf_token, "Failed to get csrf token from login page"
+ csrf_token_str = csrf_token.group(1)
+ assert csrf_token_str, "Failed to get csrf token from login page"
+ # login with form data
+ login_response = session.post(
+ f"http://{DOCKER_COMPOSE_HOST_PORT}/auth/login",
+ data={
+ "username": AIRFLOW_WWW_USER_USERNAME,
+ "password": AIRFLOW_WWW_USER_PASSWORD,
+ "csrf_token": csrf_token_str,
+ },
+ )
+ redirect_url = login_response.url
+ # ensure redirect_url is a string
+ redirect_url_str = str(redirect_url) if redirect_url is not None else ""
+ assert "/?token" in redirect_url_str, f"Login failed with redirect url {redirect_url_str}"
+ parsed_url = urlparse(redirect_url_str)
+ query_params = parse_qs(str(parsed_url.query))
+ jwt_token_list = query_params.get("token")
+ jwt_token = jwt_token_list[0] if jwt_token_list else None
+ assert jwt_token, f"Failed to get JWT token from redirect url {redirect_url_str}"
+ return jwt_token
+
+
+def api_request(
+ method: str, path: str, base_url: str = f"http://{DOCKER_COMPOSE_HOST_PORT}/public", **kwargs
+) -> dict:
response = requests.request(
method=method,
url=f"{base_url}/{path}",
- auth=(AIRFLOW_WWW_USER_USERNAME, AIRFLOW_WWW_USER_PASSWORD),
- headers={"Content-Type": "application/json"},
+ headers={"Authorization": f"Bearer {get_jwt_token()}", "Content-Type": "application/json"},
**kwargs,
)
response.raise_for_status()
diff --git a/kubernetes_tests/test_base.py b/kubernetes_tests/test_base.py
index 31e1924c18ad8..31248b02ac420 100644
--- a/kubernetes_tests/test_base.py
+++ b/kubernetes_tests/test_base.py
@@ -25,6 +25,7 @@
from datetime import datetime, timezone
from pathlib import Path
from subprocess import check_call, check_output
+from urllib.parse import parse_qs, urlparse
import pytest
import requests
@@ -59,9 +60,12 @@ class BaseK8STest:
@pytest.fixture(autouse=True)
def base_tests_setup(self, request):
- self.set_api_server_base_url_config()
- self.rollout_restart_deployment("airflow-api-server")
- self.ensure_deployment_health("airflow-api-server")
+ if self.set_api_server_base_url_config():
+ # only restart the deployment if the configmap was updated
+ # speed up the test and make the airflow-api-server deployment more stable
+ self.rollout_restart_deployment("airflow-api-server")
+ self.ensure_deployment_health("airflow-api-server")
+
# Replacement for unittests.TestCase.id()
self.test_id = f"{request.node.cls.__name__}_{request.node.name}"
self.session = self._get_session_with_retries()
@@ -126,17 +130,92 @@ def _delete_airflow_pod(name=""):
if names:
check_call(["kubectl", "delete", "pod", names[0]])
+ @staticmethod
+ def _get_jwt_token(username: str, password: str) -> str:
+ """Get the JWT token for the given username and password.
+
+ Note: API server is still using FAB Auth Manager.
+
+ Steps:
+ 1. Get the login page to get the csrf token
+ - The csrf token is in the hidden input field with id "csrf_token"
+ 2. Login with the username and password
+ - Must use the same session to keep the csrf token session
+ 3. Extract the JWT token from the redirect url
+ - Expected to have a connection error
+ - The redirect url should have the JWT token as a query parameter
+
+ :param session: The session to use for the request
+ :param username: The username to use for the login
+ :param password: The password to use for the login
+ :return: The JWT token
+ """
+ # get csrf token from login page
+ retry = Retry(total=5, backoff_factor=10)
+ session = requests.Session()
+ session.mount("http://", HTTPAdapter(max_retries=retry))
+ session.mount("https://", HTTPAdapter(max_retries=retry))
+ get_login_form_response = session.get(f"http://{KUBERNETES_HOST_PORT}/auth/login")
+ csrf_token = re.search(
+ r'',
+ get_login_form_response.text,
+ )
+ assert csrf_token, "Failed to get csrf token from login page"
+ csrf_token_str = csrf_token.group(1)
+ assert csrf_token_str, "Failed to get csrf token from login page"
+ # login with form data
+ login_response = session.post(
+ f"http://{KUBERNETES_HOST_PORT}/auth/login",
+ data={"username": username, "password": password, "csrf_token": csrf_token_str},
+ )
+ redirect_url = login_response.url
+ # ensure redirect_url is a string
+ redirect_url_str = str(redirect_url) if redirect_url is not None else ""
+ assert "/?token" in redirect_url_str, f"Login failed with redirect url {redirect_url_str}"
+ parsed_url = urlparse(redirect_url_str)
+ query_params = parse_qs(str(parsed_url.query))
+ jwt_token_list = query_params.get("token")
+ jwt_token = jwt_token_list[0] if jwt_token_list else None
+ assert jwt_token, f"Failed to get JWT token from redirect url {redirect_url_str}"
+ return jwt_token
+
def _get_session_with_retries(self):
+ class JWTRefreshAdapter(HTTPAdapter):
+ def __init__(self, base_instance, **kwargs):
+ self.base_instance = base_instance
+ super().__init__(**kwargs)
+
+ def send(self, request, **kwargs):
+ response = super().send(request, **kwargs)
+ if response.status_code in (401, 403):
+ # Refresh token and update the Authorization header with retry logic.
+ attempts = 0
+ jwt_token = None
+ while attempts < 5:
+ try:
+ jwt_token = self.base_instance._get_jwt_token("admin", "admin")
+ break
+ except Exception:
+ attempts += 1
+ time.sleep(1)
+ if jwt_token is None:
+ raise Exception("Failed to refresh JWT token after 5 attempts")
+ request.headers["Authorization"] = f"Bearer {jwt_token}"
+ response = super().send(request, **kwargs)
+ return response
+
+ jwt_token = self._get_jwt_token("admin", "admin")
session = requests.Session()
- session.auth = ("admin", "admin")
+ session.headers.update({"Authorization": f"Bearer {jwt_token}"})
retries = Retry(
- total=3,
+ total=5,
backoff_factor=10,
status_forcelist=[404],
allowed_methods=Retry.DEFAULT_ALLOWED_METHODS | frozenset(["PATCH", "POST"]),
)
- session.mount("http://", HTTPAdapter(max_retries=retries))
- session.mount("https://", HTTPAdapter(max_retries=retries))
+ adapter = JWTRefreshAdapter(self, max_retries=retries)
+ session.mount("http://", adapter)
+ session.mount("https://", adapter)
return session
def _ensure_airflow_api_server_is_healthy(self):
@@ -236,8 +315,11 @@ def _parse_airflow_cfg_dict_as_escaped_toml(self, airflow_cfg_dict: dict) -> str
# escape newlines and double quotes
return airflow_cfg_str.replace("\n", "\\n").replace('"', '\\"')
- def set_api_server_base_url_config(self):
- """Set [api/base_url] with `f"http://{KUBERNETES_HOST_PORT}"` as env in k8s configmap."""
+ def set_api_server_base_url_config(self) -> bool:
+ """Set [api/base_url] with `f"http://{KUBERNETES_HOST_PORT}"` as env in k8s configmap.
+
+ :return: True if the configmap was updated successfully, False otherwise
+ """
configmap_name = "airflow-config"
configmap_key = "airflow.cfg"
original_configmap_json_str = check_output(
@@ -250,7 +332,7 @@ def set_api_server_base_url_config(self):
airflow_cfg_dict = self._parse_airflow_cfg_as_dict(original_airflow_cfg)
airflow_cfg_dict["api"]["base_url"] = f"http://{KUBERNETES_HOST_PORT}"
# update the configmap with the new airflow.cfg
- check_call(
+ patch_configmap_result = check_output(
[
"kubectl",
"patch",
@@ -263,7 +345,10 @@ def set_api_server_base_url_config(self):
"-p",
f'{{"data": {{"{configmap_key}": "{self._parse_airflow_cfg_dict_as_escaped_toml(airflow_cfg_dict)}"}}}}',
]
- )
+ ).decode()
+ if "(no change)" in patch_configmap_result:
+ return False
+ return True
def ensure_dag_expected_state(self, host, logical_date, dag_id, expected_final_state, timeout):
tries = 0
diff --git a/kubernetes_tests/test_kubernetes_executor.py b/kubernetes_tests/test_kubernetes_executor.py
index 63d1389b17103..8a7596f3cda70 100644
--- a/kubernetes_tests/test_kubernetes_executor.py
+++ b/kubernetes_tests/test_kubernetes_executor.py
@@ -50,7 +50,7 @@ def test_integration_run_dag(self):
timeout=300,
)
- @pytest.mark.execution_timeout(300)
+ @pytest.mark.execution_timeout(500)
def test_integration_run_dag_with_scheduler_failure(self):
dag_id = "example_kubernetes_executor"
diff --git a/kubernetes_tests/test_other_executors.py b/kubernetes_tests/test_other_executors.py
index f8203b069b404..327e252825a37 100644
--- a/kubernetes_tests/test_other_executors.py
+++ b/kubernetes_tests/test_other_executors.py
@@ -16,8 +16,6 @@
# under the License.
from __future__ import annotations
-import time
-
import pytest
from kubernetes_tests.test_base import (
@@ -68,8 +66,7 @@ def test_integration_run_dag_with_scheduler_failure(self):
dag_run_id, logical_date = self.start_job_in_kubernetes(dag_id, self.host)
self._delete_airflow_pod("scheduler")
-
- time.sleep(10) # give time for pod to restart
+ self.ensure_deployment_health("airflow-scheduler")
# Wait some time for the operator to complete
self.monitor_task(
diff --git a/tests/api_fastapi/core_api/routes/public/test_dags.py b/tests/api_fastapi/core_api/routes/public/test_dags.py
index 21f1bd0235c52..93c7c20ea3afa 100644
--- a/tests/api_fastapi/core_api/routes/public/test_dags.py
+++ b/tests/api_fastapi/core_api/routes/public/test_dags.py
@@ -235,13 +235,31 @@ class TestGetDags(TestDagEndpoint):
)
def test_get_dags(self, test_client, query_params, expected_total_entries, expected_ids):
response = test_client.get("/public/dags", params=query_params)
-
assert response.status_code == 200
body = response.json()
assert body["total_entries"] == expected_total_entries
assert [dag["dag_id"] for dag in body["dags"]] == expected_ids
+ @mock.patch("airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_permitted_dag_ids")
+ def test_get_dags_should_call_permitted_dag_ids(self, mock_get_permitted_dag_ids, test_client):
+ mock_get_permitted_dag_ids.return_value = {DAG1_ID, DAG2_ID}
+ response = test_client.get("/public/dags")
+ mock_get_permitted_dag_ids.assert_called_once_with(user=mock.ANY, method="GET")
+ assert response.status_code == 200
+ body = response.json()
+
+ assert body["total_entries"] == 2
+ assert [dag["dag_id"] for dag in body["dags"]] == [DAG1_ID, DAG2_ID]
+
+ def test_get_dags_should_response_401(self, unauthenticated_test_client):
+ response = unauthenticated_test_client.get("/public/dags")
+ assert response.status_code == 401
+
+ def test_get_dags_should_response_403(self, unauthorized_test_client):
+ response = unauthorized_test_client.get("/public/dags")
+ assert response.status_code == 403
+
class TestPatchDag(TestDagEndpoint):
"""Unit tests for Patch DAG."""
@@ -268,6 +286,14 @@ def test_patch_dag(
assert body["is_paused"] == expected_is_paused
check_last_log(session, dag_id=dag_id, event="patch_dag", logical_date=None)
+ def test_patch_dag_should_response_401(self, unauthenticated_test_client):
+ response = unauthenticated_test_client.patch(f"/public/dags/{DAG1_ID}", json={"is_paused": True})
+ assert response.status_code == 401
+
+ def test_patch_dag_should_response_403(self, unauthorized_test_client):
+ response = unauthorized_test_client.patch(f"/public/dags/{DAG1_ID}", json={"is_paused": True})
+ assert response.status_code == 403
+
class TestPatchDags(TestDagEndpoint):
"""Unit tests for Patch DAGs."""
@@ -333,6 +359,26 @@ def test_patch_dags(
assert paused_dag_ids == expected_paused_ids
check_last_log(session, dag_id=DAG1_ID, event="patch_dag", logical_date=None)
+ @mock.patch("airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_permitted_dag_ids")
+ def test_patch_dags_should_call_permitted_dag_ids(self, mock_get_permitted_dag_ids, test_client):
+ mock_get_permitted_dag_ids.return_value = {DAG1_ID, DAG2_ID}
+ response = test_client.patch(
+ "/public/dags", json={"is_paused": False}, params={"only_active": False, "dag_id_pattern": "~"}
+ )
+ mock_get_permitted_dag_ids.assert_called_once_with(user=mock.ANY, method="PUT")
+ assert response.status_code == 200
+ body = response.json()
+
+ assert [dag["dag_id"] for dag in body["dags"]] == [DAG1_ID, DAG2_ID]
+
+ def test_patch_dags_should_response_401(self, unauthenticated_test_client):
+ response = unauthenticated_test_client.patch("/public/dags", json={"is_paused": True})
+ assert response.status_code == 401
+
+ def test_patch_dags_should_response_403(self, unauthorized_test_client):
+ response = unauthorized_test_client.patch("/public/dags", json={"is_paused": True})
+ assert response.status_code == 403
+
class TestDagDetails(TestDagEndpoint):
"""Unit tests for DAG Details."""
@@ -414,6 +460,14 @@ def test_dag_details(
}
assert res_json == expected
+ def test_dag_details_should_response_401(self, unauthenticated_test_client):
+ response = unauthenticated_test_client.get(f"/public/dags/{DAG1_ID}/details")
+ assert response.status_code == 401
+
+ def test_dag_details_should_response_403(self, unauthorized_test_client):
+ response = unauthorized_test_client.get(f"/public/dags/{DAG1_ID}/details")
+ assert response.status_code == 403
+
class TestGetDag(TestDagEndpoint):
"""Unit tests for Get DAG."""
@@ -462,6 +516,14 @@ def test_get_dag(self, test_client, query_params, dag_id, expected_status_code,
}
assert res_json == expected
+ def test_get_dag_should_response_401(self, unauthenticated_test_client):
+ response = unauthenticated_test_client.get(f"/public/dags/{DAG1_ID}")
+ assert response.status_code == 401
+
+ def test_get_dag_should_response_403(self, unauthorized_test_client):
+ response = unauthorized_test_client.get(f"/public/dags/{DAG1_ID}")
+ assert response.status_code == 403
+
class TestDeleteDAG(TestDagEndpoint):
"""Unit tests for Delete DAG."""
@@ -521,5 +583,14 @@ def test_delete_dag(
details_response = test_client.get(f"{API_PREFIX}/{dag_id}/details")
assert details_response.status_code == status_code_details
+
if details_response.status_code == 204:
check_last_log(session, dag_id=dag_id, event="delete_dag", logical_date=None)
+
+ def test_delete_dag_should_response_401(self, unauthenticated_test_client):
+ response = unauthenticated_test_client.delete(f"{API_PREFIX}/{DAG1_ID}")
+ assert response.status_code == 401
+
+ def test_delete_dag_should_response_403(self, unauthorized_test_client):
+ response = unauthorized_test_client.delete(f"{API_PREFIX}/{DAG1_ID}")
+ assert response.status_code == 403
diff --git a/tests/api_fastapi/core_api/test_security.py b/tests/api_fastapi/core_api/test_security.py
index f777cda5e83f0..193cc67798f91 100644
--- a/tests/api_fastapi/core_api/test_security.py
+++ b/tests/api_fastapi/core_api/test_security.py
@@ -88,11 +88,10 @@ def test_requires_access_dag_authorized(self, mock_get_auth_manager):
auth_manager = Mock()
auth_manager.is_authorized_dag.return_value = True
mock_get_auth_manager.return_value = auth_manager
+ fastapi_request = Mock()
+ fastapi_request.path_params.return_value = {}
- mock_request = Mock()
- mock_request.path_params.return_value = {"dag_id": "test"}
-
- requires_access_dag("GET", DagAccessEntity.CODE)(mock_request, Mock())
+ requires_access_dag("GET", DagAccessEntity.CODE)(fastapi_request, Mock())
auth_manager.is_authorized_dag.assert_called_once()
@@ -101,11 +100,13 @@ def test_requires_access_dag_unauthorized(self, mock_get_auth_manager):
auth_manager = Mock()
auth_manager.is_authorized_dag.return_value = False
mock_get_auth_manager.return_value = auth_manager
+ fastapi_request = Mock()
+ fastapi_request.path_params.return_value = {}
mock_request = Mock()
mock_request.path_params.return_value = {}
with pytest.raises(HTTPException, match="Forbidden"):
- requires_access_dag("GET", DagAccessEntity.CODE)(mock_request, Mock())
+ requires_access_dag("GET", DagAccessEntity.CODE)(fastapi_request, Mock())
auth_manager.is_authorized_dag.assert_called_once()