Skip to content

Commit

Permalink
make paginated_select kw-only (#44239)
Browse files Browse the repository at this point in the history
* make paginated_select kw-only

* make paginated_select kw-only

* rename base_select to select

* rename base_select to select
  • Loading branch information
rawwar authored Nov 21, 2024
1 parent 9fb3d07 commit a924284
Show file tree
Hide file tree
Showing 14 changed files with 92 additions and 76 deletions.
5 changes: 3 additions & 2 deletions airflow/api_fastapi/common/db/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def apply_filters_to_select(base_select: Select, filters: Sequence[BaseParam | N

@provide_session
def paginated_select(
base_select: Select,
*,
select: Select,
filters: Sequence[BaseParam],
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
Expand All @@ -68,7 +69,7 @@ def paginated_select(
return_total_entries: bool = True,
) -> Select:
base_select = apply_filters_to_select(
base_select,
select,
filters,
)

Expand Down
14 changes: 4 additions & 10 deletions airflow/api_fastapi/core_api/routes/public/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def get_assets(
) -> AssetCollectionResponse:
"""Get assets."""
assets_select, total_entries = paginated_select(
select(AssetModel),
select=select(AssetModel),
filters=[uri_pattern, dag_ids],
order_by=order_by,
offset=offset,
Expand Down Expand Up @@ -144,7 +144,7 @@ def get_asset_events(
) -> AssetEventCollectionResponse:
"""Get asset events."""
assets_event_select, total_entries = paginated_select(
select(AssetEvent),
select=select(AssetEvent),
filters=[asset_id, source_dag_id, source_task_id, source_run_id, source_map_index],
order_by=order_by,
offset=offset,
Expand Down Expand Up @@ -211,10 +211,7 @@ def get_asset_queued_events(
.where(*where_clause)
)

dag_asset_queued_events_select, total_entries = paginated_select(
query,
[],
)
dag_asset_queued_events_select, total_entries = paginated_select(select=query, filters=[])
adrqs = session.execute(dag_asset_queued_events_select).all()

if not adrqs:
Expand Down Expand Up @@ -273,10 +270,7 @@ def get_dag_asset_queued_events(
.where(*where_clause)
)

dag_asset_queued_events_select, total_entries = paginated_select(
query,
[],
)
dag_asset_queued_events_select, total_entries = paginated_select(select=query, filters=[])
adrqs = session.execute(dag_asset_queued_events_select).all()

if not adrqs:
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_fastapi/core_api/routes/public/backfills.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def list_backfills(
session: Annotated[Session, Depends(get_session)],
) -> BackfillCollectionResponse:
select_stmt, total_entries = paginated_select(
select(Backfill).where(Backfill.dag_id == dag_id),
[],
select=select(Backfill).where(Backfill.dag_id == dag_id),
filters=[],
order_by=order_by,
offset=offset,
limit=limit,
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_fastapi/core_api/routes/public/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def get_connections(
) -> ConnectionCollectionResponse:
"""Get all connection entries."""
connection_select, total_entries = paginated_select(
select(Connection),
[],
select=select(Connection),
filters=[],
order_by=order_by,
offset=offset,
limit=limit,
Expand Down
12 changes: 6 additions & 6 deletions airflow/api_fastapi/core_api/routes/public/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,12 @@ def get_dag_runs(
base_query = base_query.filter(DagRun.dag_id == dag_id)

dag_run_select, total_entries = paginated_select(
base_query,
[logical_date, start_date_range, end_date_range, update_at_range, state],
order_by,
offset,
limit,
session,
select=base_query,
filters=[logical_date, start_date_range, end_date_range, update_at_range, state],
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)

dag_runs = session.scalars(dag_run_select)
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_fastapi/core_api/routes/public/dag_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_dag_stats(
) -> DagStatsCollectionResponse:
"""Get Dag statistics."""
dagruns_select, _ = paginated_select(
base_select=dagruns_select_with_state_count,
select=dagruns_select_with_state_count,
filters=[dag_ids],
session=session,
return_total_entries=False,
Expand Down
7 changes: 6 additions & 1 deletion airflow/api_fastapi/core_api/routes/public/dag_warning.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,12 @@ def list_dag_warnings(
) -> DAGWarningCollectionResponse:
"""Get a list of DAG warnings."""
dag_warnings_select, total_entries = paginated_select(
select(DagWarning), [warning_type, dag_id], order_by, offset, limit, session
select=select(DagWarning),
filters=[warning_type, dag_id],
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)

dag_warnings = session.scalars(dag_warnings_select)
Expand Down
34 changes: 21 additions & 13 deletions airflow/api_fastapi/core_api/routes/public/dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,20 @@ def get_dags(
) -> DAGCollectionResponse:
"""Get all DAGs."""
dags_select, total_entries = paginated_select(
dags_select_with_latest_dag_run,
[only_active, paused, dag_id_pattern, dag_display_name_pattern, tags, owners, last_dag_run_state],
order_by,
offset,
limit,
session,
select=dags_select_with_latest_dag_run,
filters=[
only_active,
paused,
dag_id_pattern,
dag_display_name_pattern,
tags,
owners,
last_dag_run_state,
],
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)

dags = session.scalars(dags_select)
Expand Down Expand Up @@ -119,7 +127,7 @@ def get_dag_tags(
"""Get all DAG tags."""
base_select = select(DagTag.name).group_by(DagTag.name)
dag_tags_select, total_entries = paginated_select(
base_select=base_select,
select=base_select,
filters=[tag_name_pattern],
order_by=order_by,
offset=offset,
Expand Down Expand Up @@ -254,12 +262,12 @@ def patch_dags(
update_mask = ["is_paused"]

dags_select, total_entries = paginated_select(
dags_select_with_latest_dag_run,
[only_active, paused, dag_id_pattern, tags, owners, last_dag_run_state],
None,
offset,
limit,
session,
select=dags_select_with_latest_dag_run,
filters=[only_active, paused, dag_id_pattern, tags, owners, last_dag_run_state],
order_by=None,
offset=offset,
limit=limit,
session=session,
)

dags = session.scalars(dags_select).all()
Expand Down
12 changes: 6 additions & 6 deletions airflow/api_fastapi/core_api/routes/public/event_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,12 @@ def get_event_logs(
if after is not None:
base_select = base_select.where(Log.dttm > after)
event_logs_select, total_entries = paginated_select(
base_select,
[],
order_by,
offset,
limit,
session,
select=base_select,
filters=[],
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)
event_logs = session.scalars(event_logs_select)

Expand Down
12 changes: 6 additions & 6 deletions airflow/api_fastapi/core_api/routes/public/import_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ def get_import_errors(
) -> ImportErrorCollectionResponse:
"""Get all import errors."""
import_errors_select, total_entries = paginated_select(
select(ParseImportError),
[],
order_by,
offset,
limit,
session,
select=select(ParseImportError),
filters=[],
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)
import_errors = session.scalars(import_errors_select)

Expand Down
4 changes: 2 additions & 2 deletions airflow/api_fastapi/core_api/routes/public/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def get_pools(
) -> PoolCollectionResponse:
"""Get all pools entries."""
pools_select, total_entries = paginated_select(
select(Pool),
[],
select=select(Pool),
filters=[],
order_by=order_by,
offset=offset,
limit=limit,
Expand Down
36 changes: 18 additions & 18 deletions airflow/api_fastapi/core_api/routes/public/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,8 @@ def get_mapped_task_instances(
raise HTTPException(status.HTTP_404_NOT_FOUND, error_message)

task_instance_select, total_entries = paginated_select(
base_query,
[
select=base_query,
filters=[
logical_date_range,
start_date_range,
end_date_range,
Expand All @@ -164,10 +164,10 @@ def get_mapped_task_instances(
queue,
executor,
],
order_by,
offset,
limit,
session,
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)

task_instances = session.scalars(task_instance_select)
Expand Down Expand Up @@ -318,8 +318,8 @@ def get_task_instances(
base_query = base_query.where(TI.run_id == dag_run_id)

task_instance_select, total_entries = paginated_select(
base_query,
[
select=base_query,
filters=[
logical_date,
start_date_range,
end_date_range,
Expand All @@ -330,10 +330,10 @@ def get_task_instances(
queue,
executor,
],
order_by,
offset,
limit,
session,
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)

task_instances = session.scalars(task_instance_select)
Expand Down Expand Up @@ -392,8 +392,8 @@ def get_task_instances_batch(

base_query = select(TI).join(TI.dag_run)
task_instance_select, total_entries = paginated_select(
base_query,
[
select=base_query,
filters=[
dag_ids,
dag_run_ids,
task_ids,
Expand All @@ -406,10 +406,10 @@ def get_task_instances_batch(
queue,
executor,
],
order_by,
offset,
limit,
session,
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)

task_instance_select = task_instance_select.options(
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_fastapi/core_api/routes/public/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def get_variables(
) -> VariableCollectionResponse:
"""Get all Variables entries."""
variable_select, total_entries = paginated_select(
select(Variable),
[],
select=select(Variable),
filters=[],
order_by=order_by,
offset=offset,
limit=limit,
Expand Down
18 changes: 13 additions & 5 deletions airflow/api_fastapi/core_api/routes/ui/dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,19 @@ def recent_dag_runs(
.order_by(recent_runs_subquery.c.logical_date.desc())
)
dags_with_recent_dag_runs_select_filter, _ = paginated_select(
dags_with_recent_dag_runs_select,
[only_active, paused, dag_id_pattern, dag_display_name_pattern, tags, owners, last_dag_run_state],
None,
offset,
limit,
select=dags_with_recent_dag_runs_select,
filters=[
only_active,
paused,
dag_id_pattern,
dag_display_name_pattern,
tags,
owners,
last_dag_run_state,
],
order_by=None,
offset=offset,
limit=limit,
)
dags_with_recent_dag_runs = session.execute(dags_with_recent_dag_runs_select_filter)
# aggregate rows by dag_id
Expand Down

0 comments on commit a924284

Please sign in to comment.