Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 26 additions & 26 deletions airflow/api_fastapi/common/db/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
if TYPE_CHECKING:
from sqlalchemy.sql import Select

from airflow.api_fastapi.common.parameters import BaseParam
from airflow.api_fastapi.core_api.base import OrmClause


def _get_session() -> Session:
Expand All @@ -47,7 +47,7 @@ def _get_session() -> Session:


def apply_filters_to_select(
*, statement: Select, filters: Sequence[BaseParam | None] | None = None
*, statement: Select, filters: Sequence[OrmClause | None] | None = None
) -> Select:
if filters is None:
return statement
Expand All @@ -71,10 +71,10 @@ async def _get_async_session() -> AsyncSession:
async def paginated_select_async(
*,
statement: Select,
filters: Sequence[BaseParam] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
filters: Sequence[OrmClause] | None = None,
order_by: OrmClause | None = None,
offset: OrmClause | None = None,
limit: OrmClause | None = None,
session: AsyncSession,
return_total_entries: Literal[True] = True,
) -> tuple[Select, int]: ...
Expand All @@ -84,10 +84,10 @@ async def paginated_select_async(
async def paginated_select_async(
*,
statement: Select,
filters: Sequence[BaseParam] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
filters: Sequence[OrmClause] | None = None,
order_by: OrmClause | None = None,
offset: OrmClause | None = None,
limit: OrmClause | None = None,
session: AsyncSession,
return_total_entries: Literal[False],
) -> tuple[Select, None]: ...
Expand All @@ -96,10 +96,10 @@ async def paginated_select_async(
async def paginated_select_async(
*,
statement: Select,
filters: Sequence[BaseParam | None] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
filters: Sequence[OrmClause | None] | None = None,
order_by: OrmClause | None = None,
offset: OrmClause | None = None,
limit: OrmClause | None = None,
session: AsyncSession,
return_total_entries: bool = True,
) -> tuple[Select, int | None]:
Expand Down Expand Up @@ -129,10 +129,10 @@ async def paginated_select_async(
def paginated_select(
*,
statement: Select,
filters: Sequence[BaseParam] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
filters: Sequence[OrmClause] | None = None,
order_by: OrmClause | None = None,
offset: OrmClause | None = None,
limit: OrmClause | None = None,
session: Session = NEW_SESSION,
return_total_entries: Literal[True] = True,
) -> tuple[Select, int]: ...
Expand All @@ -142,10 +142,10 @@ def paginated_select(
def paginated_select(
*,
statement: Select,
filters: Sequence[BaseParam] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
filters: Sequence[OrmClause] | None = None,
order_by: OrmClause | None = None,
offset: OrmClause | None = None,
limit: OrmClause | None = None,
session: Session = NEW_SESSION,
return_total_entries: Literal[False],
) -> tuple[Select, None]: ...
Expand All @@ -155,10 +155,10 @@ def paginated_select(
def paginated_select(
*,
statement: Select,
filters: Sequence[BaseParam] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
filters: Sequence[OrmClause] | None = None,
order_by: OrmClause | None = None,
offset: OrmClause | None = None,
limit: OrmClause | None = None,
session: Session = NEW_SESSION,
return_total_entries: bool = True,
) -> tuple[Select, int | None]:
Expand Down
11 changes: 4 additions & 7 deletions airflow/api_fastapi/common/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from sqlalchemy import Column, and_, case, or_
from sqlalchemy.inspection import inspect

from airflow.api_fastapi.core_api.base import OrmClause
from airflow.models import Base
from airflow.models.asset import (
AssetAliasModel,
Expand All @@ -65,18 +66,14 @@
T = TypeVar("T")


class BaseParam(Generic[T], ABC):
"""Base class for filters."""
class BaseParam(OrmClause[T], ABC):
"""Base class for path or query parameters with ORM transformation."""

def __init__(self, value: T | None = None, skip_none: bool = True) -> None:
self.value = value
super().__init__(value)
self.attribute: ColumnElement | None = None
self.skip_none = skip_none

@abstractmethod
def to_orm(self, select: Select) -> Select:
pass

def set_value(self, value: T | None) -> Self:
self.value = value
return self
Expand Down
23 changes: 23 additions & 0 deletions airflow/api_fastapi/core_api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,16 @@
# under the License.
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Generic, TypeVar

from pydantic import BaseModel as PydanticBaseModel, ConfigDict

if TYPE_CHECKING:
from sqlalchemy.sql import Select

T = TypeVar("T")


class BaseModel(PydanticBaseModel):
"""
Expand All @@ -39,3 +47,18 @@ class StrictBaseModel(BaseModel):
"""

model_config = ConfigDict(from_attributes=True, populate_by_name=True, extra="forbid")


class OrmClause(Generic[T], ABC):
"""
Base class for filtering clauses with paginated_select.

The subclasses should implement the `to_orm` method and set the `value` attribute.
"""

def __init__(self, value: T | None = None):
self.value = value

@abstractmethod
def to_orm(self, select: Select) -> Select:
pass
12 changes: 12 additions & 0 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3058,6 +3058,8 @@ paths:
summary: Get Dags
description: Get all DAGs.
operationId: get_dags
security:
- OAuth2PasswordBearer: []
parameters:
- name: limit
in: query
Expand Down Expand Up @@ -3223,6 +3225,8 @@ paths:
summary: Patch Dags
description: Patch multiple DAGs.
operationId: patch_dags
security:
- OAuth2PasswordBearer: []
parameters:
- name: update_mask
in: query
Expand Down Expand Up @@ -3358,6 +3362,8 @@ paths:
summary: Get Dag
description: Get basic information about a DAG.
operationId: get_dag
security:
- OAuth2PasswordBearer: []
parameters:
- name: dag_id
in: path
Expand Down Expand Up @@ -3408,6 +3414,8 @@ paths:
summary: Patch Dag
description: Patch the specific DAG.
operationId: patch_dag
security:
- OAuth2PasswordBearer: []
parameters:
- name: dag_id
in: path
Expand Down Expand Up @@ -3474,6 +3482,8 @@ paths:
summary: Delete Dag
description: Delete the specific DAG.
operationId: delete_dag
security:
- OAuth2PasswordBearer: []
parameters:
- name: dag_id
in: path
Expand Down Expand Up @@ -3524,6 +3534,8 @@ paths:
summary: Get Dag Details
description: Get details of DAG.
operationId: get_dag_details
security:
- OAuth2PasswordBearer: []
parameters:
- name: dag_id
in: path
Expand Down
20 changes: 15 additions & 5 deletions airflow/api_fastapi/core_api/routes/public/dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@
DAGResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import (
EditableDagsFilterDep,
ReadableDagsFilterDep,
requires_access_dag,
)
from airflow.api_fastapi.logging.decorators import action_logging
from airflow.exceptions import AirflowException, DagNotFound
from airflow.models import DAG, DagModel
Expand All @@ -65,7 +70,7 @@
dags_router = AirflowRouter(tags=["DAG"], prefix="/dags")


@dags_router.get("")
@dags_router.get("", dependencies=[Depends(requires_access_dag(method="GET"))])
def get_dags(
limit: QueryLimit,
offset: QueryOffset,
Expand Down Expand Up @@ -105,6 +110,7 @@ def get_dags(
).dynamic_depends()
),
],
readable_dags_filter: ReadableDagsFilterDep,
session: SessionDep,
) -> DAGCollectionResponse:
"""Get all DAGs."""
Expand Down Expand Up @@ -132,6 +138,7 @@ def get_dags(
tags,
owners,
last_dag_run_state,
readable_dags_filter,
],
order_by=order_by,
offset=offset,
Expand All @@ -156,6 +163,7 @@ def get_dags(
status.HTTP_422_UNPROCESSABLE_ENTITY,
]
),
dependencies=[Depends(requires_access_dag(method="GET"))],
)
def get_dag(dag_id: str, session: SessionDep, request: Request) -> DAGResponse:
"""Get basic information about a DAG."""
Expand All @@ -182,6 +190,7 @@ def get_dag(dag_id: str, session: SessionDep, request: Request) -> DAGResponse:
status.HTTP_404_NOT_FOUND,
]
),
dependencies=[Depends(requires_access_dag(method="GET"))],
)
def get_dag_details(dag_id: str, session: SessionDep, request: Request) -> DAGDetailsResponse:
"""Get details of DAG."""
Expand All @@ -208,7 +217,7 @@ def get_dag_details(dag_id: str, session: SessionDep, request: Request) -> DAGDe
status.HTTP_404_NOT_FOUND,
]
),
dependencies=[Depends(action_logging())],
dependencies=[Depends(requires_access_dag(method="PUT")), Depends(action_logging())],
)
def patch_dag(
dag_id: str,
Expand Down Expand Up @@ -251,7 +260,7 @@ def patch_dag(
status.HTTP_404_NOT_FOUND,
]
),
dependencies=[Depends(action_logging())],
dependencies=[Depends(requires_access_dag(method="PUT")), Depends(action_logging())],
)
def patch_dags(
patch_body: DAGPatchBody,
Expand All @@ -263,6 +272,7 @@ def patch_dags(
only_active: QueryOnlyActiveFilter,
paused: QueryPausedFilter,
last_dag_run_state: QueryLastDagRunStateFilter,
editable_dags_filter: EditableDagsFilterDep,
session: SessionDep,
update_mask: list[str] | None = Query(None),
) -> DAGCollectionResponse:
Expand All @@ -283,7 +293,7 @@ def patch_dags(

dags_select, total_entries = paginated_select(
statement=generate_dag_with_latest_run_query(),
filters=[only_active, paused, dag_id_pattern, tags, owners, last_dag_run_state],
filters=[only_active, paused, dag_id_pattern, tags, owners, last_dag_run_state, editable_dags_filter],
order_by=None,
offset=offset,
limit=limit,
Expand Down Expand Up @@ -313,7 +323,7 @@ def patch_dags(
status.HTTP_422_UNPROCESSABLE_ENTITY,
]
),
dependencies=[Depends(action_logging())],
dependencies=[Depends(requires_access_dag(method="DELETE")), Depends(action_logging())],
)
def delete_dag(
dag_id: str,
Expand Down
Loading