Skip to content

Commit

Permalink
Make filters param optional and fix typing
Browse files Browse the repository at this point in the history
Given that sometimes we don't want to apply any filters, it makes sense to make the param optional.  I also fix the typing on `paginated_select`.
  • Loading branch information
dstandish committed Nov 20, 2024
1 parent d43052e commit 697edf0
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 40 deletions.
30 changes: 20 additions & 10 deletions airflow/api_fastapi/common/db/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Database helpers for Airflow REST API.
:meta private:
"""

from __future__ import annotations

Expand Down Expand Up @@ -47,40 +52,45 @@ def your_route(session: Annotated[Session, Depends(get_session)]):
yield session


def apply_filters_to_select(base_select: Select, filters: Sequence[BaseParam | None]) -> Select:
base_select = base_select
for filter in filters:
if filter is None:
def apply_filters_to_select(
*, base_select: Select, filters: Sequence[BaseParam | None] | None = None
) -> Select:
if filters is None:
return base_select
for f in filters:
if f is None:
continue
base_select = filter.to_orm(base_select)
base_select = f.to_orm(base_select)

return base_select


@provide_session
def paginated_select(
base_select: Select,
filters: Sequence[BaseParam],
filters: Sequence[BaseParam] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
session: Session = NEW_SESSION,
return_total_entries: bool = True,
) -> Select:
) -> tuple[Select, int | None]:
base_select = apply_filters_to_select(
base_select,
filters,
base_select=base_select,
filters=filters,
)

total_entries = None
if return_total_entries:
total_entries = get_query_count(base_select, session=session)
if TYPE_CHECKING:
assert isinstance(total_entries, int)

# TODO: Re-enable when permissions are handled. Readable / writable entities,
# for instance:
# readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user)
# dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags))

base_select = apply_filters_to_select(base_select, [order_by, offset, limit])
base_select = apply_filters_to_select(base_select=base_select, filters=[order_by, offset, limit])

return base_select, total_entries
22 changes: 13 additions & 9 deletions airflow/api_fastapi/core_api/routes/public/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

from datetime import datetime
from typing import Annotated
from typing import TYPE_CHECKING, Annotated

from fastapi import Depends, HTTPException, status
from sqlalchemy import delete, select
Expand Down Expand Up @@ -102,6 +102,9 @@ def get_assets(
limit=limit,
session=session,
)
if TYPE_CHECKING:
assert isinstance(total_entries, int)

assets = session.scalars(
assets_select.options(
subqueryload(AssetModel.consuming_dags), subqueryload(AssetModel.producing_tasks)
Expand Down Expand Up @@ -151,6 +154,8 @@ def get_asset_events(
limit=limit,
session=session,
)
if TYPE_CHECKING:
assert isinstance(total_entries, int)

assets_event_select = assets_event_select.options(subqueryload(AssetEvent.created_dagruns))
assets_events = session.scalars(assets_event_select)
Expand Down Expand Up @@ -211,11 +216,10 @@ def get_asset_queued_events(
.where(*where_clause)
)

dag_asset_queued_events_select, total_entries = paginated_select(
query,
[],
)
dag_asset_queued_events_select, total_entries = paginated_select(query)
adrqs = session.execute(dag_asset_queued_events_select).all()
if TYPE_CHECKING:
assert isinstance(total_entries, int)

if not adrqs:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Queue event with uri: `{uri}` was not found")
Expand Down Expand Up @@ -273,10 +277,10 @@ def get_dag_asset_queued_events(
.where(*where_clause)
)

dag_asset_queued_events_select, total_entries = paginated_select(
query,
[],
)
dag_asset_queued_events_select, total_entries = paginated_select(query)
if TYPE_CHECKING:
assert isinstance(total_entries, int)

adrqs = session.execute(dag_asset_queued_events_select).all()

if not adrqs:
Expand Down
1 change: 0 additions & 1 deletion airflow/api_fastapi/core_api/routes/public/backfills.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def list_backfills(
) -> BackfillCollectionResponse:
select_stmt, total_entries = paginated_select(
select(Backfill).where(Backfill.dag_id == dag_id),
[],
order_by=order_by,
offset=offset,
limit=limit,
Expand Down
5 changes: 3 additions & 2 deletions airflow/api_fastapi/core_api/routes/public/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

import os
from typing import Annotated
from typing import TYPE_CHECKING, Annotated

from fastapi import Depends, HTTPException, Query, status
from sqlalchemy import select
Expand Down Expand Up @@ -101,12 +101,13 @@ def get_connections(
"""Get all connection entries."""
connection_select, total_entries = paginated_select(
select(Connection),
[],
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)
if TYPE_CHECKING:
assert isinstance(total_entries, int)

connections = session.scalars(connection_select)

Expand Down
14 changes: 5 additions & 9 deletions airflow/api_fastapi/core_api/routes/public/event_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

from datetime import datetime
from typing import Annotated
from typing import TYPE_CHECKING, Annotated

from fastapi import Depends, HTTPException, Query, status
from sqlalchemy import select
Expand Down Expand Up @@ -126,21 +126,17 @@ def get_event_logs(
base_select = base_select.where(Log.dttm > after)
event_logs_select, total_entries = paginated_select(
base_select,
[],
None,
order_by,
offset,
limit,
session,
)
event_logs = session.scalars(event_logs_select)
if TYPE_CHECKING:
assert isinstance(total_entries, int)

return EventLogCollectionResponse(
event_logs=[
EventLogResponse.model_validate(
event_log,
from_attributes=True,
)
for event_log in event_logs
],
event_logs=[EventLogResponse.model_validate(x, from_attributes=True) for x in event_logs],
total_entries=total_entries,
)
10 changes: 5 additions & 5 deletions airflow/api_fastapi/core_api/routes/public/import_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

from typing import Annotated
from typing import TYPE_CHECKING, Annotated

from fastapi import Depends, HTTPException, status
from sqlalchemy import select
Expand Down Expand Up @@ -90,17 +90,17 @@ def get_import_errors(
"""Get all import errors."""
import_errors_select, total_entries = paginated_select(
select(ParseImportError),
[],
None,
order_by,
offset,
limit,
session,
)
import_errors = session.scalars(import_errors_select)
if TYPE_CHECKING:
assert isinstance(total_entries, int)

return ImportErrorCollectionResponse(
import_errors=[
ImportErrorResponse.model_validate(error, from_attributes=True) for error in import_errors
],
import_errors=[ImportErrorResponse.model_validate(x, from_attributes=True) for x in import_errors],
total_entries=total_entries,
)
5 changes: 3 additions & 2 deletions airflow/api_fastapi/core_api/routes/public/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

from typing import Annotated
from typing import TYPE_CHECKING, Annotated

from fastapi import Depends, HTTPException, Query, status
from fastapi.exceptions import RequestValidationError
Expand Down Expand Up @@ -96,12 +96,13 @@ def get_pools(
"""Get all pools entries."""
pools_select, total_entries = paginated_select(
select(Pool),
[],
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)
if TYPE_CHECKING:
assert isinstance(total_entries, int)

pools = session.scalars(pools_select)

Expand Down
5 changes: 3 additions & 2 deletions airflow/api_fastapi/core_api/routes/public/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

from typing import Annotated
from typing import TYPE_CHECKING, Annotated

from fastapi import Depends, HTTPException, Query, status
from sqlalchemy import select
Expand Down Expand Up @@ -91,12 +91,13 @@ def get_variables(
"""Get all Variables entries."""
variable_select, total_entries = paginated_select(
select(Variable),
[],
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)
if TYPE_CHECKING:
assert isinstance(total_entries, int)

variables = session.scalars(variable_select)

Expand Down

0 comments on commit 697edf0

Please sign in to comment.