From a9242844706ca117f86d22092109939dd56435ee Mon Sep 17 00:00:00 2001 From: Kalyan R Date: Thu, 21 Nov 2024 19:24:41 +0530 Subject: [PATCH] make paginated_select kw-only (#44239) * make paginated_select kw-only * make paginated_select kw-only * rename base_select to select * rename base_select to select --- airflow/api_fastapi/common/db/common.py | 5 +-- .../core_api/routes/public/assets.py | 14 +++----- .../core_api/routes/public/backfills.py | 4 +-- .../core_api/routes/public/connections.py | 4 +-- .../core_api/routes/public/dag_run.py | 12 +++---- .../core_api/routes/public/dag_stats.py | 2 +- .../core_api/routes/public/dag_warning.py | 7 +++- .../core_api/routes/public/dags.py | 34 +++++++++++------- .../core_api/routes/public/event_logs.py | 12 +++---- .../core_api/routes/public/import_error.py | 12 +++---- .../core_api/routes/public/pools.py | 4 +-- .../core_api/routes/public/task_instances.py | 36 +++++++++---------- .../core_api/routes/public/variables.py | 4 +-- .../api_fastapi/core_api/routes/ui/dags.py | 18 +++++++--- 14 files changed, 92 insertions(+), 76 deletions(-) diff --git a/airflow/api_fastapi/common/db/common.py b/airflow/api_fastapi/common/db/common.py index b4462bda420c9..e083cf650fd8d 100644 --- a/airflow/api_fastapi/common/db/common.py +++ b/airflow/api_fastapi/common/db/common.py @@ -59,7 +59,8 @@ def apply_filters_to_select(base_select: Select, filters: Sequence[BaseParam | N @provide_session def paginated_select( - base_select: Select, + *, + select: Select, filters: Sequence[BaseParam], order_by: BaseParam | None = None, offset: BaseParam | None = None, @@ -68,7 +69,7 @@ def paginated_select( return_total_entries: bool = True, ) -> Select: base_select = apply_filters_to_select( - base_select, + select, filters, ) diff --git a/airflow/api_fastapi/core_api/routes/public/assets.py b/airflow/api_fastapi/core_api/routes/public/assets.py index 785f76c9662ba..7da7fcb8e878d 100644 --- a/airflow/api_fastapi/core_api/routes/public/assets.py +++ b/airflow/api_fastapi/core_api/routes/public/assets.py @@ -95,7 +95,7 @@ def get_assets( ) -> AssetCollectionResponse: """Get assets.""" assets_select, total_entries = paginated_select( - select(AssetModel), + select=select(AssetModel), filters=[uri_pattern, dag_ids], order_by=order_by, offset=offset, @@ -144,7 +144,7 @@ def get_asset_events( ) -> AssetEventCollectionResponse: """Get asset events.""" assets_event_select, total_entries = paginated_select( - select(AssetEvent), + select=select(AssetEvent), filters=[asset_id, source_dag_id, source_task_id, source_run_id, source_map_index], order_by=order_by, offset=offset, @@ -211,10 +211,7 @@ def get_asset_queued_events( .where(*where_clause) ) - dag_asset_queued_events_select, total_entries = paginated_select( - query, - [], - ) + dag_asset_queued_events_select, total_entries = paginated_select(select=query, filters=[]) adrqs = session.execute(dag_asset_queued_events_select).all() if not adrqs: @@ -273,10 +270,7 @@ def get_dag_asset_queued_events( .where(*where_clause) ) - dag_asset_queued_events_select, total_entries = paginated_select( - query, - [], - ) + dag_asset_queued_events_select, total_entries = paginated_select(select=query, filters=[]) adrqs = session.execute(dag_asset_queued_events_select).all() if not adrqs: diff --git a/airflow/api_fastapi/core_api/routes/public/backfills.py b/airflow/api_fastapi/core_api/routes/public/backfills.py index 5adddc12a4d15..c4ab7ce16b603 100644 --- a/airflow/api_fastapi/core_api/routes/public/backfills.py +++ b/airflow/api_fastapi/core_api/routes/public/backfills.py @@ -60,8 +60,8 @@ def list_backfills( session: Annotated[Session, Depends(get_session)], ) -> BackfillCollectionResponse: select_stmt, total_entries = paginated_select( - select(Backfill).where(Backfill.dag_id == dag_id), - [], + select=select(Backfill).where(Backfill.dag_id == dag_id), + filters=[], order_by=order_by, offset=offset, limit=limit, diff --git a/airflow/api_fastapi/core_api/routes/public/connections.py b/airflow/api_fastapi/core_api/routes/public/connections.py index c79174bd2cca7..1ca158bad5dac 100644 --- a/airflow/api_fastapi/core_api/routes/public/connections.py +++ b/airflow/api_fastapi/core_api/routes/public/connections.py @@ -100,8 +100,8 @@ def get_connections( ) -> ConnectionCollectionResponse: """Get all connection entries.""" connection_select, total_entries = paginated_select( - select(Connection), - [], + select=select(Connection), + filters=[], order_by=order_by, offset=offset, limit=limit, diff --git a/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow/api_fastapi/core_api/routes/public/dag_run.py index 6ce60fe896d1c..18a3128a8047e 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_run.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_run.py @@ -288,12 +288,12 @@ def get_dag_runs( base_query = base_query.filter(DagRun.dag_id == dag_id) dag_run_select, total_entries = paginated_select( - base_query, - [logical_date, start_date_range, end_date_range, update_at_range, state], - order_by, - offset, - limit, - session, + select=base_query, + filters=[logical_date, start_date_range, end_date_range, update_at_range, state], + order_by=order_by, + offset=offset, + limit=limit, + session=session, ) dag_runs = session.scalars(dag_run_select) diff --git a/airflow/api_fastapi/core_api/routes/public/dag_stats.py b/airflow/api_fastapi/core_api/routes/public/dag_stats.py index 49c5e7ec48f94..119961f8c5f36 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_stats.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_stats.py @@ -55,7 +55,7 @@ def get_dag_stats( ) -> DagStatsCollectionResponse: """Get Dag statistics.""" dagruns_select, _ = paginated_select( - base_select=dagruns_select_with_state_count, + select=dagruns_select_with_state_count, filters=[dag_ids], session=session, return_total_entries=False, diff --git a/airflow/api_fastapi/core_api/routes/public/dag_warning.py b/airflow/api_fastapi/core_api/routes/public/dag_warning.py index 2213df9bdc425..3560874ad1aa9 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_warning.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_warning.py @@ -60,7 +60,12 @@ def list_dag_warnings( ) -> DAGWarningCollectionResponse: """Get a list of DAG warnings.""" dag_warnings_select, total_entries = paginated_select( - select(DagWarning), [warning_type, dag_id], order_by, offset, limit, session + select=select(DagWarning), + filters=[warning_type, dag_id], + order_by=order_by, + offset=offset, + limit=limit, + session=session, ) dag_warnings = session.scalars(dag_warnings_select) diff --git a/airflow/api_fastapi/core_api/routes/public/dags.py b/airflow/api_fastapi/core_api/routes/public/dags.py index f44f2f7e8a5f3..0383584fe1729 100644 --- a/airflow/api_fastapi/core_api/routes/public/dags.py +++ b/airflow/api_fastapi/core_api/routes/public/dags.py @@ -82,12 +82,20 @@ def get_dags( ) -> DAGCollectionResponse: """Get all DAGs.""" dags_select, total_entries = paginated_select( - dags_select_with_latest_dag_run, - [only_active, paused, dag_id_pattern, dag_display_name_pattern, tags, owners, last_dag_run_state], - order_by, - offset, - limit, - session, + select=dags_select_with_latest_dag_run, + filters=[ + only_active, + paused, + dag_id_pattern, + dag_display_name_pattern, + tags, + owners, + last_dag_run_state, + ], + order_by=order_by, + offset=offset, + limit=limit, + session=session, ) dags = session.scalars(dags_select) @@ -119,7 +127,7 @@ def get_dag_tags( """Get all DAG tags.""" base_select = select(DagTag.name).group_by(DagTag.name) dag_tags_select, total_entries = paginated_select( - base_select=base_select, + select=base_select, filters=[tag_name_pattern], order_by=order_by, offset=offset, @@ -254,12 +262,12 @@ def patch_dags( update_mask = ["is_paused"] dags_select, total_entries = paginated_select( - dags_select_with_latest_dag_run, - [only_active, paused, dag_id_pattern, tags, owners, last_dag_run_state], - None, - offset, - limit, - session, + select=dags_select_with_latest_dag_run, + filters=[only_active, paused, dag_id_pattern, tags, owners, last_dag_run_state], + order_by=None, + offset=offset, + limit=limit, + session=session, ) dags = session.scalars(dags_select).all() diff --git a/airflow/api_fastapi/core_api/routes/public/event_logs.py b/airflow/api_fastapi/core_api/routes/public/event_logs.py index abe0371a735b7..166b48995399a 100644 --- a/airflow/api_fastapi/core_api/routes/public/event_logs.py +++ b/airflow/api_fastapi/core_api/routes/public/event_logs.py @@ -125,12 +125,12 @@ def get_event_logs( if after is not None: base_select = base_select.where(Log.dttm > after) event_logs_select, total_entries = paginated_select( - base_select, - [], - order_by, - offset, - limit, - session, + select=base_select, + filters=[], + order_by=order_by, + offset=offset, + limit=limit, + session=session, ) event_logs = session.scalars(event_logs_select) diff --git a/airflow/api_fastapi/core_api/routes/public/import_error.py b/airflow/api_fastapi/core_api/routes/public/import_error.py index 3db8cf178f049..29f783081d78c 100644 --- a/airflow/api_fastapi/core_api/routes/public/import_error.py +++ b/airflow/api_fastapi/core_api/routes/public/import_error.py @@ -89,12 +89,12 @@ def get_import_errors( ) -> ImportErrorCollectionResponse: """Get all import errors.""" import_errors_select, total_entries = paginated_select( - select(ParseImportError), - [], - order_by, - offset, - limit, - session, + select=select(ParseImportError), + filters=[], + order_by=order_by, + offset=offset, + limit=limit, + session=session, ) import_errors = session.scalars(import_errors_select) diff --git a/airflow/api_fastapi/core_api/routes/public/pools.py b/airflow/api_fastapi/core_api/routes/public/pools.py index 79b055b07447f..df14b9aae5a05 100644 --- a/airflow/api_fastapi/core_api/routes/public/pools.py +++ b/airflow/api_fastapi/core_api/routes/public/pools.py @@ -95,8 +95,8 @@ def get_pools( ) -> PoolCollectionResponse: """Get all pools entries.""" pools_select, total_entries = paginated_select( - select(Pool), - [], + select=select(Pool), + filters=[], order_by=order_by, offset=offset, limit=limit, diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow/api_fastapi/core_api/routes/public/task_instances.py index f4769a981b882..6d5b427abb1dc 100644 --- a/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -152,8 +152,8 @@ def get_mapped_task_instances( raise HTTPException(status.HTTP_404_NOT_FOUND, error_message) task_instance_select, total_entries = paginated_select( - base_query, - [ + select=base_query, + filters=[ logical_date_range, start_date_range, end_date_range, @@ -164,10 +164,10 @@ def get_mapped_task_instances( queue, executor, ], - order_by, - offset, - limit, - session, + order_by=order_by, + offset=offset, + limit=limit, + session=session, ) task_instances = session.scalars(task_instance_select) @@ -318,8 +318,8 @@ def get_task_instances( base_query = base_query.where(TI.run_id == dag_run_id) task_instance_select, total_entries = paginated_select( - base_query, - [ + select=base_query, + filters=[ logical_date, start_date_range, end_date_range, @@ -330,10 +330,10 @@ def get_task_instances( queue, executor, ], - order_by, - offset, - limit, - session, + order_by=order_by, + offset=offset, + limit=limit, + session=session, ) task_instances = session.scalars(task_instance_select) @@ -392,8 +392,8 @@ def get_task_instances_batch( base_query = select(TI).join(TI.dag_run) task_instance_select, total_entries = paginated_select( - base_query, - [ + select=base_query, + filters=[ dag_ids, dag_run_ids, task_ids, @@ -406,10 +406,10 @@ def get_task_instances_batch( queue, executor, ], - order_by, - offset, - limit, - session, + order_by=order_by, + offset=offset, + limit=limit, + session=session, ) task_instance_select = task_instance_select.options( diff --git a/airflow/api_fastapi/core_api/routes/public/variables.py b/airflow/api_fastapi/core_api/routes/public/variables.py index fe966deda0b6b..a9e479f4f853a 100644 --- a/airflow/api_fastapi/core_api/routes/public/variables.py +++ b/airflow/api_fastapi/core_api/routes/public/variables.py @@ -90,8 +90,8 @@ def get_variables( ) -> VariableCollectionResponse: """Get all Variables entries.""" variable_select, total_entries = paginated_select( - select(Variable), - [], + select=select(Variable), + filters=[], order_by=order_by, offset=offset, limit=limit, diff --git a/airflow/api_fastapi/core_api/routes/ui/dags.py b/airflow/api_fastapi/core_api/routes/ui/dags.py index fad736ced379c..96b8b0c1b109f 100644 --- a/airflow/api_fastapi/core_api/routes/ui/dags.py +++ b/airflow/api_fastapi/core_api/routes/ui/dags.py @@ -103,11 +103,19 @@ def recent_dag_runs( .order_by(recent_runs_subquery.c.logical_date.desc()) ) dags_with_recent_dag_runs_select_filter, _ = paginated_select( - dags_with_recent_dag_runs_select, - [only_active, paused, dag_id_pattern, dag_display_name_pattern, tags, owners, last_dag_run_state], - None, - offset, - limit, + select=dags_with_recent_dag_runs_select, + filters=[ + only_active, + paused, + dag_id_pattern, + dag_display_name_pattern, + tags, + owners, + last_dag_run_state, + ], + order_by=None, + offset=offset, + limit=limit, ) dags_with_recent_dag_runs = session.execute(dags_with_recent_dag_runs_select_filter) # aggregate rows by dag_id