From f611ffc4261658dc57e3f8d41ff47ba1ed2c0c97 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Thu, 21 Nov 2024 13:14:27 -0800 Subject: [PATCH 1/6] Add base model for rest api; set from_attributes=True This enables us to have simpler API code, since we don't need to call `model_validate` as much. --- airflow/api_fastapi/core_api/base.py | 29 +++++++++++++++++++ .../api_fastapi/core_api/datamodels/assets.py | 3 +- .../core_api/datamodels/backfills.py | 3 +- .../api_fastapi/core_api/datamodels/config.py | 2 +- .../core_api/datamodels/connections.py | 3 +- .../core_api/datamodels/dag_run.py | 3 +- .../core_api/datamodels/dag_sources.py | 2 +- .../core_api/datamodels/dag_stats.py | 3 +- .../core_api/datamodels/dag_warning.py | 3 +- .../api_fastapi/core_api/datamodels/dags.py | 2 +- .../core_api/datamodels/event_logs.py | 4 ++- .../core_api/datamodels/import_error.py | 4 ++- .../api_fastapi/core_api/datamodels/job.py | 4 ++- .../core_api/datamodels/monitor.py | 2 +- .../core_api/datamodels/plugins.py | 3 +- .../api_fastapi/core_api/datamodels/pools.py | 4 ++- .../core_api/datamodels/providers.py | 2 +- .../core_api/datamodels/task_instances.py | 2 +- .../api_fastapi/core_api/datamodels/tasks.py | 3 +- .../core_api/datamodels/trigger.py | 4 ++- .../core_api/datamodels/variables.py | 3 +- .../core_api/datamodels/version.py | 2 +- .../api_fastapi/core_api/datamodels/xcom.py | 4 ++- .../core_api/routes/public/backfills.py | 4 +-- 24 files changed, 71 insertions(+), 27 deletions(-) create mode 100644 airflow/api_fastapi/core_api/base.py diff --git a/airflow/api_fastapi/core_api/base.py b/airflow/api_fastapi/core_api/base.py new file mode 100644 index 000000000000..52df0e6fea5d --- /dev/null +++ b/airflow/api_fastapi/core_api/base.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from pydantic import BaseModel as PydanticBaseModel, ConfigDict + + +class BaseModel(PydanticBaseModel): + """ + Base pydantic model for REST API. + + :meta private: + """ + + model_config = ConfigDict(from_attributes=True) diff --git a/airflow/api_fastapi/core_api/datamodels/assets.py b/airflow/api_fastapi/core_api/datamodels/assets.py index 94ec17ad63d9..adc32c2e4808 100644 --- a/airflow/api_fastapi/core_api/datamodels/assets.py +++ b/airflow/api_fastapi/core_api/datamodels/assets.py @@ -19,8 +19,9 @@ from datetime import datetime -from pydantic import BaseModel, Field, field_validator +from pydantic import Field, field_validator +from airflow.api_fastapi.core_api.base import BaseModel from airflow.utils.log.secrets_masker import redact diff --git a/airflow/api_fastapi/core_api/datamodels/backfills.py b/airflow/api_fastapi/core_api/datamodels/backfills.py index 69d6a98ccfd1..be04063907a9 100644 --- a/airflow/api_fastapi/core_api/datamodels/backfills.py +++ b/airflow/api_fastapi/core_api/datamodels/backfills.py @@ -19,8 +19,7 @@ from datetime import datetime -from pydantic import BaseModel - +from airflow.api_fastapi.core_api.base import BaseModel from airflow.models.backfill import ReprocessBehavior diff --git a/airflow/api_fastapi/core_api/datamodels/config.py b/airflow/api_fastapi/core_api/datamodels/config.py index 0627832e45f4..c16aa98093fb 100644 --- a/airflow/api_fastapi/core_api/datamodels/config.py +++ b/airflow/api_fastapi/core_api/datamodels/config.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from pydantic import BaseModel +from airflow.api_fastapi.core_api.base import BaseModel class ConfigOption(BaseModel): diff --git a/airflow/api_fastapi/core_api/datamodels/connections.py b/airflow/api_fastapi/core_api/datamodels/connections.py index 7b23682cc8ef..d74ced1ba4d3 100644 --- a/airflow/api_fastapi/core_api/datamodels/connections.py +++ b/airflow/api_fastapi/core_api/datamodels/connections.py @@ -19,9 +19,10 @@ import json -from pydantic import BaseModel, Field, field_validator +from pydantic import Field, field_validator from pydantic_core.core_schema import ValidationInfo +from airflow.api_fastapi.core_api.base import BaseModel from airflow.utils.log.secrets_masker import redact diff --git a/airflow/api_fastapi/core_api/datamodels/dag_run.py b/airflow/api_fastapi/core_api/datamodels/dag_run.py index f3343e6c407d..d211b0205b3b 100644 --- a/airflow/api_fastapi/core_api/datamodels/dag_run.py +++ b/airflow/api_fastapi/core_api/datamodels/dag_run.py @@ -20,8 +20,9 @@ from datetime import datetime from enum import Enum -from pydantic import BaseModel, Field +from pydantic import Field +from airflow.api_fastapi.core_api.base import BaseModel from airflow.utils.state import DagRunState from airflow.utils.types import DagRunTriggeredByType, DagRunType diff --git a/airflow/api_fastapi/core_api/datamodels/dag_sources.py b/airflow/api_fastapi/core_api/datamodels/dag_sources.py index 4b84bd1e6b1c..6db3f334b805 100644 --- a/airflow/api_fastapi/core_api/datamodels/dag_sources.py +++ b/airflow/api_fastapi/core_api/datamodels/dag_sources.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from pydantic import BaseModel +from airflow.api_fastapi.core_api.base import BaseModel class DAGSourceResponse(BaseModel): diff --git a/airflow/api_fastapi/core_api/datamodels/dag_stats.py b/airflow/api_fastapi/core_api/datamodels/dag_stats.py index 0d768c2cbac0..1effdd5a94f7 100644 --- a/airflow/api_fastapi/core_api/datamodels/dag_stats.py +++ b/airflow/api_fastapi/core_api/datamodels/dag_stats.py @@ -17,8 +17,7 @@ from __future__ import annotations -from pydantic import BaseModel - +from airflow.api_fastapi.core_api.base import BaseModel from airflow.utils.state import DagRunState diff --git a/airflow/api_fastapi/core_api/datamodels/dag_warning.py b/airflow/api_fastapi/core_api/datamodels/dag_warning.py index f38a3a8d093f..f1dbf1411916 100644 --- a/airflow/api_fastapi/core_api/datamodels/dag_warning.py +++ b/airflow/api_fastapi/core_api/datamodels/dag_warning.py @@ -19,8 +19,7 @@ from datetime import datetime -from pydantic import BaseModel - +from airflow.api_fastapi.core_api.base import BaseModel from airflow.models.dagwarning import DagWarningType diff --git a/airflow/api_fastapi/core_api/datamodels/dags.py b/airflow/api_fastapi/core_api/datamodels/dags.py index 27cc3ad47356..9f2e764cec0b 100644 --- a/airflow/api_fastapi/core_api/datamodels/dags.py +++ b/airflow/api_fastapi/core_api/datamodels/dags.py @@ -25,12 +25,12 @@ from pendulum.tz.timezone import FixedTimezone, Timezone from pydantic import ( AliasGenerator, - BaseModel, ConfigDict, computed_field, field_validator, ) +from airflow.api_fastapi.core_api.base import BaseModel from airflow.configuration import conf from airflow.serialization.pydantic.dag import DagTagPydantic diff --git a/airflow/api_fastapi/core_api/datamodels/event_logs.py b/airflow/api_fastapi/core_api/datamodels/event_logs.py index 5b65ec85ba7b..1d1f8d165cd1 100644 --- a/airflow/api_fastapi/core_api/datamodels/event_logs.py +++ b/airflow/api_fastapi/core_api/datamodels/event_logs.py @@ -19,7 +19,9 @@ from datetime import datetime -from pydantic import BaseModel, ConfigDict, Field +from pydantic import ConfigDict, Field + +from airflow.api_fastapi.core_api.base import BaseModel class EventLogResponse(BaseModel): diff --git a/airflow/api_fastapi/core_api/datamodels/import_error.py b/airflow/api_fastapi/core_api/datamodels/import_error.py index ebc65e23eccb..93e496057cab 100644 --- a/airflow/api_fastapi/core_api/datamodels/import_error.py +++ b/airflow/api_fastapi/core_api/datamodels/import_error.py @@ -18,7 +18,9 @@ from datetime import datetime -from pydantic import BaseModel, ConfigDict, Field +from pydantic import ConfigDict, Field + +from airflow.api_fastapi.core_api.base import BaseModel class ImportErrorResponse(BaseModel): diff --git a/airflow/api_fastapi/core_api/datamodels/job.py b/airflow/api_fastapi/core_api/datamodels/job.py index e4d5ceb4b4e2..883622a67bfa 100644 --- a/airflow/api_fastapi/core_api/datamodels/job.py +++ b/airflow/api_fastapi/core_api/datamodels/job.py @@ -18,7 +18,9 @@ from datetime import datetime -from pydantic import BaseModel, ConfigDict +from pydantic import ConfigDict + +from airflow.api_fastapi.core_api.base import BaseModel class JobResponse(BaseModel): diff --git a/airflow/api_fastapi/core_api/datamodels/monitor.py b/airflow/api_fastapi/core_api/datamodels/monitor.py index 0734321a45fd..fbaf40b4e841 100644 --- a/airflow/api_fastapi/core_api/datamodels/monitor.py +++ b/airflow/api_fastapi/core_api/datamodels/monitor.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from pydantic import BaseModel +from airflow.api_fastapi.core_api.base import BaseModel class BaseInfoSchema(BaseModel): diff --git a/airflow/api_fastapi/core_api/datamodels/plugins.py b/airflow/api_fastapi/core_api/datamodels/plugins.py index cc305ed3aa88..798ba6fa85d3 100644 --- a/airflow/api_fastapi/core_api/datamodels/plugins.py +++ b/airflow/api_fastapi/core_api/datamodels/plugins.py @@ -19,8 +19,9 @@ from typing import Annotated, Any -from pydantic import BaseModel, BeforeValidator, ConfigDict, field_validator +from pydantic import BeforeValidator, ConfigDict, field_validator +from airflow.api_fastapi.core_api.base import BaseModel from airflow.plugins_manager import AirflowPluginSource diff --git a/airflow/api_fastapi/core_api/datamodels/pools.py b/airflow/api_fastapi/core_api/datamodels/pools.py index ef3676a8afec..762ad4a819bd 100644 --- a/airflow/api_fastapi/core_api/datamodels/pools.py +++ b/airflow/api_fastapi/core_api/datamodels/pools.py @@ -19,7 +19,9 @@ from typing import Annotated, Callable -from pydantic import BaseModel, BeforeValidator, ConfigDict, Field +from pydantic import BeforeValidator, ConfigDict, Field + +from airflow.api_fastapi.core_api.base import BaseModel def _call_function(function: Callable[[], int]) -> int: diff --git a/airflow/api_fastapi/core_api/datamodels/providers.py b/airflow/api_fastapi/core_api/datamodels/providers.py index 4e542f19f9f8..8b515fafd2da 100644 --- a/airflow/api_fastapi/core_api/datamodels/providers.py +++ b/airflow/api_fastapi/core_api/datamodels/providers.py @@ -17,7 +17,7 @@ from __future__ import annotations -from pydantic import BaseModel +from airflow.api_fastapi.core_api.base import BaseModel class ProviderResponse(BaseModel): diff --git a/airflow/api_fastapi/core_api/datamodels/task_instances.py b/airflow/api_fastapi/core_api/datamodels/task_instances.py index 4712df3273a4..9a04e8831b13 100644 --- a/airflow/api_fastapi/core_api/datamodels/task_instances.py +++ b/airflow/api_fastapi/core_api/datamodels/task_instances.py @@ -22,13 +22,13 @@ from pydantic import ( AliasPath, AwareDatetime, - BaseModel, BeforeValidator, ConfigDict, Field, NonNegativeInt, ) +from airflow.api_fastapi.core_api.base import BaseModel from airflow.api_fastapi.core_api.datamodels.job import JobResponse from airflow.api_fastapi.core_api.datamodels.trigger import TriggerResponse from airflow.utils.state import TaskInstanceState diff --git a/airflow/api_fastapi/core_api/datamodels/tasks.py b/airflow/api_fastapi/core_api/datamodels/tasks.py index 9b962390cc34..0806d4453c49 100644 --- a/airflow/api_fastapi/core_api/datamodels/tasks.py +++ b/airflow/api_fastapi/core_api/datamodels/tasks.py @@ -22,9 +22,10 @@ from datetime import datetime from typing import Any -from pydantic import BaseModel, computed_field, field_validator, model_validator +from pydantic import computed_field, field_validator, model_validator from airflow.api_fastapi.common.types import TimeDeltaWithValidation +from airflow.api_fastapi.core_api.base import BaseModel from airflow.models.mappedoperator import MappedOperator from airflow.serialization.serialized_objects import SerializedBaseOperator, encode_priority_weight_strategy from airflow.task.priority_strategy import PriorityWeightStrategy diff --git a/airflow/api_fastapi/core_api/datamodels/trigger.py b/airflow/api_fastapi/core_api/datamodels/trigger.py index eb9be97d3140..265d40ff19bf 100644 --- a/airflow/api_fastapi/core_api/datamodels/trigger.py +++ b/airflow/api_fastapi/core_api/datamodels/trigger.py @@ -19,7 +19,9 @@ from datetime import datetime from typing import Annotated -from pydantic import BaseModel, BeforeValidator, ConfigDict +from pydantic import BeforeValidator, ConfigDict + +from airflow.api_fastapi.core_api.base import BaseModel class TriggerResponse(BaseModel): diff --git a/airflow/api_fastapi/core_api/datamodels/variables.py b/airflow/api_fastapi/core_api/datamodels/variables.py index 9a2ce996d3a4..ff9b0209278f 100644 --- a/airflow/api_fastapi/core_api/datamodels/variables.py +++ b/airflow/api_fastapi/core_api/datamodels/variables.py @@ -19,8 +19,9 @@ import json -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import ConfigDict, Field, model_validator +from airflow.api_fastapi.core_api.base import BaseModel from airflow.typing_compat import Self from airflow.utils.log.secrets_masker import redact diff --git a/airflow/api_fastapi/core_api/datamodels/version.py b/airflow/api_fastapi/core_api/datamodels/version.py index 01c4c45376f7..b29864776c6f 100644 --- a/airflow/api_fastapi/core_api/datamodels/version.py +++ b/airflow/api_fastapi/core_api/datamodels/version.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from pydantic import BaseModel +from airflow.api_fastapi.core_api.base import BaseModel class VersionInfo(BaseModel): diff --git a/airflow/api_fastapi/core_api/datamodels/xcom.py b/airflow/api_fastapi/core_api/datamodels/xcom.py index 186b5aad77f0..370aa651cb2c 100644 --- a/airflow/api_fastapi/core_api/datamodels/xcom.py +++ b/airflow/api_fastapi/core_api/datamodels/xcom.py @@ -19,7 +19,9 @@ from datetime import datetime from typing import Any -from pydantic import BaseModel, field_validator +from pydantic import field_validator + +from airflow.api_fastapi.core_api.base import BaseModel class XComResponse(BaseModel): diff --git a/airflow/api_fastapi/core_api/routes/public/backfills.py b/airflow/api_fastapi/core_api/routes/public/backfills.py index 9c5dd0895c47..29264bad540f 100644 --- a/airflow/api_fastapi/core_api/routes/public/backfills.py +++ b/airflow/api_fastapi/core_api/routes/public/backfills.py @@ -70,7 +70,7 @@ def list_backfills( backfills = session.scalars(select_stmt) return BackfillCollectionResponse( - backfills=[BackfillResponse.model_validate(b, from_attributes=True) for b in backfills], + backfills=backfills, total_entries=total_entries, ) @@ -85,7 +85,7 @@ def get_backfill( ) -> BackfillResponse: backfill = session.get(Backfill, backfill_id) if backfill: - return BackfillResponse.model_validate(backfill, from_attributes=True) + return BackfillResponse.model_validate(backfill) raise HTTPException(status.HTTP_404_NOT_FOUND, "Backfill not found") From f8d44ab731ccb296137b054733acd7c8b3d81ca2 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Thu, 21 Nov 2024 13:25:37 -0800 Subject: [PATCH 2/6] remove the explicit model_validate calls --- .../api_fastapi/core_api/datamodels/dags.py | 2 +- .../core_api/routes/public/assets.py | 15 +++++---------- .../core_api/routes/public/backfills.py | 8 ++++---- .../core_api/routes/public/connections.py | 14 +++++--------- .../core_api/routes/public/dag_run.py | 18 +++++++----------- .../core_api/routes/public/dag_warning.py | 3 +-- .../api_fastapi/core_api/routes/public/dags.py | 10 +++++----- .../core_api/routes/public/event_logs.py | 2 +- .../core_api/routes/public/import_error.py | 2 +- .../core_api/routes/public/pools.py | 8 ++++---- .../core_api/routes/public/task_instances.py | 13 ++++++------- .../core_api/routes/public/tasks.py | 4 ++-- .../core_api/routes/public/variables.py | 8 ++++---- .../api_fastapi/core_api/routes/public/xcom.py | 4 ++-- airflow/api_fastapi/core_api/routes/ui/dags.py | 4 ++-- .../core_api/routes/ui/dashboard.py | 2 +- .../execution_api/routes/connections.py | 2 +- 17 files changed, 52 insertions(+), 67 deletions(-) diff --git a/airflow/api_fastapi/core_api/datamodels/dags.py b/airflow/api_fastapi/core_api/datamodels/dags.py index 9f2e764cec0b..c4c66fd017d0 100644 --- a/airflow/api_fastapi/core_api/datamodels/dags.py +++ b/airflow/api_fastapi/core_api/datamodels/dags.py @@ -144,7 +144,7 @@ def get_params(cls, params: abc.MutableMapping | None) -> dict | None: """Convert params attribute to dict representation.""" if params is None: return None - return {param_name: param_val.dump() for param_name, param_val in params.items()} + return {k: v.dump() for k, v in params.items()} # Mypy issue https://github.com/python/mypy/issues/1362 @computed_field # type: ignore[misc] diff --git a/airflow/api_fastapi/core_api/routes/public/assets.py b/airflow/api_fastapi/core_api/routes/public/assets.py index 64a5acf82604..6d3bbb8086a0 100644 --- a/airflow/api_fastapi/core_api/routes/public/assets.py +++ b/airflow/api_fastapi/core_api/routes/public/assets.py @@ -109,7 +109,7 @@ def get_assets( ) ) return AssetCollectionResponse( - assets=[AssetResponse.model_validate(asset, from_attributes=True) for asset in assets], + assets=assets, total_entries=total_entries, ) @@ -157,9 +157,7 @@ def get_asset_events( assets_events = session.scalars(assets_event_select) return AssetEventCollectionResponse( - asset_events=[ - AssetEventResponse.model_validate(asset, from_attributes=True) for asset in assets_events - ], + asset_events=[AssetEventResponse.model_validate(asset) for asset in assets_events], total_entries=total_entries, ) @@ -187,7 +185,7 @@ def create_asset_event( if not assets_event: raise HTTPException(status.HTTP_404_NOT_FOUND, f"Asset with uri: `{body.uri}` was not found") - return AssetEventResponse.model_validate(assets_event, from_attributes=True) + return AssetEventResponse.model_validate(assets_event) @assets_router.get( @@ -247,7 +245,7 @@ def get_asset( if asset is None: raise HTTPException(status.HTTP_404_NOT_FOUND, f"The Asset with uri: `{uri}` was not found") - return AssetResponse.model_validate(asset, from_attributes=True) + return AssetResponse.model_validate(asset) @assets_router.get( @@ -282,10 +280,7 @@ def get_dag_asset_queued_events( ] return QueuedEventCollectionResponse( - queued_events=[ - QueuedEventResponse.model_validate(queued_event, from_attributes=True) - for queued_event in queued_events - ], + queued_events=[QueuedEventResponse.model_validate(queued_event) for queued_event in queued_events], total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/backfills.py b/airflow/api_fastapi/core_api/routes/public/backfills.py index 29264bad540f..0ce6a90839cc 100644 --- a/airflow/api_fastapi/core_api/routes/public/backfills.py +++ b/airflow/api_fastapi/core_api/routes/public/backfills.py @@ -107,7 +107,7 @@ def pause_backfill(backfill_id, session: Annotated[Session, Depends(get_session) if b.is_paused is False: b.is_paused = True session.commit() - return BackfillResponse.model_validate(b, from_attributes=True) + return BackfillResponse.model_validate(b) @backfills_router.put( @@ -127,7 +127,7 @@ def unpause_backfill(backfill_id, session: Annotated[Session, Depends(get_sessio raise HTTPException(status.HTTP_409_CONFLICT, "Backfill is already completed.") if b.is_paused: b.is_paused = False - return BackfillResponse.model_validate(b, from_attributes=True) + return BackfillResponse.model_validate(b) @backfills_router.put( @@ -172,7 +172,7 @@ def cancel_backfill(backfill_id, session: Annotated[Session, Depends(get_session # this is in separate transaction just to avoid potential conflicts session.refresh(b) b.completed_at = timezone.utcnow() - return BackfillResponse.model_validate(b, from_attributes=True) + return BackfillResponse.model_validate(b) @backfills_router.post( @@ -199,7 +199,7 @@ def create_backfill( dag_run_conf=backfill_request.dag_run_conf, reprocess_behavior=backfill_request.reprocess_behavior, ) - return BackfillResponse.model_validate(backfill_obj, from_attributes=True) + return BackfillResponse.model_validate(backfill_obj) except AlreadyRunningBackfill: raise HTTPException( status_code=status.HTTP_409_CONFLICT, diff --git a/airflow/api_fastapi/core_api/routes/public/connections.py b/airflow/api_fastapi/core_api/routes/public/connections.py index 0716c77f9f4e..d4e0388d4265 100644 --- a/airflow/api_fastapi/core_api/routes/public/connections.py +++ b/airflow/api_fastapi/core_api/routes/public/connections.py @@ -78,7 +78,7 @@ def get_connection( status.HTTP_404_NOT_FOUND, f"The Connection with connection_id: `{connection_id}` was not found" ) - return ConnectionResponse.model_validate(connection, from_attributes=True) + return ConnectionResponse.model_validate(connection) @connections_router.get( @@ -110,9 +110,7 @@ def get_connections( connections = session.scalars(connection_select) return ConnectionCollectionResponse( - connections=[ - ConnectionResponse.model_validate(connection, from_attributes=True) for connection in connections - ], + connections=connections, total_entries=total_entries, ) @@ -142,7 +140,7 @@ def post_connection( connection = Connection(**post_body.model_dump(by_alias=True)) session.add(connection) - return ConnectionResponse.model_validate(connection, from_attributes=True) + return ConnectionResponse.model_validate(connection) @connections_router.patch( @@ -182,7 +180,7 @@ def patch_connection( for key, val in data.items(): setattr(connection, key, val) - return ConnectionResponse.model_validate(connection, from_attributes=True) + return ConnectionResponse.model_validate(connection) @connections_router.post( @@ -213,8 +211,6 @@ def test_connection( conn = Connection(**data) os.environ[conn_env_var] = conn.get_uri() test_status, test_message = conn.test_connection() - return ConnectionTestResponse.model_validate( - {"status": test_status, "message": test_message}, from_attributes=True - ) + return ConnectionTestResponse.model_validate({"status": test_status, "message": test_message}) finally: os.environ.pop(conn_env_var, None) diff --git a/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow/api_fastapi/core_api/routes/public/dag_run.py index d7a196eba3f3..70e602ab684c 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_run.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_run.py @@ -38,7 +38,7 @@ datetime_range_filter_factory, ) from airflow.api_fastapi.common.router import AirflowRouter -from airflow.api_fastapi.core_api.datamodels.assets import AssetEventCollectionResponse, AssetEventResponse +from airflow.api_fastapi.core_api.datamodels.assets import AssetEventCollectionResponse from airflow.api_fastapi.core_api.datamodels.dag_run import ( DAGRunClearBody, DAGRunCollectionResponse, @@ -74,7 +74,7 @@ def get_dag_run( f"The DagRun with dag_id: `{dag_id}` and run_id: `{dag_run_id}` was not found", ) - return DAGRunResponse.model_validate(dag_run, from_attributes=True) + return DAGRunResponse.model_validate(dag_run) @dag_run_router.delete( @@ -156,7 +156,7 @@ def patch_dag_run( dag_run = session.get(DagRun, dag_run.id) - return DAGRunResponse.model_validate(dag_run, from_attributes=True) + return DAGRunResponse.model_validate(dag_run) @dag_run_router.get( @@ -184,9 +184,7 @@ def get_upstream_asset_events( ) events = dag_run.consumed_asset_events return AssetEventCollectionResponse( - asset_events=[ - AssetEventResponse.model_validate(asset_event, from_attributes=True) for asset_event in events - ], + asset_events=events, total_entries=len(events), ) @@ -223,9 +221,7 @@ def clear_dag_run( ) return TaskInstanceCollectionResponse( - task_instances=[ - TaskInstanceResponse.model_validate(ti, from_attributes=True) for ti in task_instances - ], + task_instances=[TaskInstanceResponse.model_validate(ti) for ti in task_instances], total_entries=len(task_instances), ) else: @@ -237,7 +233,7 @@ def clear_dag_run( session=session, ) dag_run_cleared = session.scalar(select(DagRun).where(DagRun.id == dag_run.id)) - return DAGRunResponse.model_validate(dag_run_cleared, from_attributes=True) + return DAGRunResponse.model_validate(dag_run_cleared) @dag_run_router.get("", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND])) @@ -297,6 +293,6 @@ def get_dag_runs( ) dag_runs = session.scalars(dag_run_select) return DAGRunCollectionResponse( - dag_runs=[DAGRunResponse.model_validate(dr, from_attributes=True) for dr in dag_runs], + dag_runs=[DAGRunResponse.model_validate(dr) for dr in dag_runs], total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/dag_warning.py b/airflow/api_fastapi/core_api/routes/public/dag_warning.py index 9ddd1439b199..e933710bc690 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_warning.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_warning.py @@ -37,7 +37,6 @@ from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.core_api.datamodels.dag_warning import ( DAGWarningCollectionResponse, - DAGWarningResponse, ) from airflow.models import DagWarning @@ -70,6 +69,6 @@ def list_dag_warnings( dag_warnings = session.scalars(dag_warnings_select) return DAGWarningCollectionResponse( - dag_warnings=[DAGWarningResponse.model_validate(w, from_attributes=True) for w in dag_warnings], + dag_warnings=dag_warnings, total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/dags.py b/airflow/api_fastapi/core_api/routes/public/dags.py index 1416855b11c9..619c55b970ef 100644 --- a/airflow/api_fastapi/core_api/routes/public/dags.py +++ b/airflow/api_fastapi/core_api/routes/public/dags.py @@ -101,7 +101,7 @@ def get_dags( dags = session.scalars(dags_select) return DAGCollectionResponse( - dags=[DAGResponse.model_validate(dag, from_attributes=True) for dag in dags], + dags=dags, total_entries=total_entries, ) @@ -162,7 +162,7 @@ def get_dag(dag_id: str, session: Annotated[Session, Depends(get_session)], requ if not key.startswith("_") and not hasattr(dag_model, key): setattr(dag_model, key, value) - return DAGResponse.model_validate(dag_model, from_attributes=True) + return DAGResponse.model_validate(dag_model) @dags_router.get( @@ -190,7 +190,7 @@ def get_dag_details( if not key.startswith("_") and not hasattr(dag_model, key): setattr(dag_model, key, value) - return DAGDetailsResponse.model_validate(dag_model, from_attributes=True) + return DAGDetailsResponse.model_validate(dag_model) @dags_router.patch( @@ -227,7 +227,7 @@ def patch_dag( for key, val in data.items(): setattr(dag, key, val) - return DAGResponse.model_validate(dag, from_attributes=True) + return DAGResponse.model_validate(dag) @dags_router.patch( @@ -280,7 +280,7 @@ def patch_dags( ) return DAGCollectionResponse( - dags=[DAGResponse.model_validate(d, from_attributes=True) for d in dags], + dags=[DAGResponse.model_validate(d) for d in dags], total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/event_logs.py b/airflow/api_fastapi/core_api/routes/public/event_logs.py index 7d2933365a95..3c57b62822e9 100644 --- a/airflow/api_fastapi/core_api/routes/public/event_logs.py +++ b/airflow/api_fastapi/core_api/routes/public/event_logs.py @@ -134,6 +134,6 @@ def get_event_logs( event_logs = session.scalars(event_logs_select) return EventLogCollectionResponse( - event_logs=[EventLogResponse.model_validate(e, from_attributes=True) for e in event_logs], + event_logs=event_logs, total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/import_error.py b/airflow/api_fastapi/core_api/routes/public/import_error.py index e17abfbe1ffc..eabf0e224013 100644 --- a/airflow/api_fastapi/core_api/routes/public/import_error.py +++ b/airflow/api_fastapi/core_api/routes/public/import_error.py @@ -98,6 +98,6 @@ def get_import_errors( import_errors = session.scalars(import_errors_select) return ImportErrorCollectionResponse( - import_errors=[ImportErrorResponse.model_validate(i, from_attributes=True) for i in import_errors], + import_errors=import_errors, total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/pools.py b/airflow/api_fastapi/core_api/routes/public/pools.py index 582e03ab00db..ec33d29938f6 100644 --- a/airflow/api_fastapi/core_api/routes/public/pools.py +++ b/airflow/api_fastapi/core_api/routes/public/pools.py @@ -77,7 +77,7 @@ def get_pool( if pool is None: raise HTTPException(status.HTTP_404_NOT_FOUND, f"The Pool with name: `{pool_name}` was not found") - return PoolResponse.model_validate(pool, from_attributes=True) + return PoolResponse.model_validate(pool) @pools_router.get( @@ -105,7 +105,7 @@ def get_pools( pools = session.scalars(pools_select) return PoolCollectionResponse( - pools=[PoolResponse.model_validate(pool, from_attributes=True) for pool in pools], + pools=pools, total_entries=total_entries, ) @@ -154,7 +154,7 @@ def patch_pool( for key, value in data.items(): setattr(pool, key, value) - return PoolResponse.model_validate(pool, from_attributes=True) + return PoolResponse.model_validate(pool) @pools_router.post( @@ -170,4 +170,4 @@ def post_pool( session.add(pool) - return PoolResponse.model_validate(pool, from_attributes=True) + return PoolResponse.model_validate(pool) diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow/api_fastapi/core_api/routes/public/task_instances.py index ed6d46dc78f0..d87ea6a0bd48 100644 --- a/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -95,7 +95,7 @@ def get_task_instance( status.HTTP_404_NOT_FOUND, "Task instance is mapped, add the map_index value to the URL" ) - return TaskInstanceResponse.model_validate(task_instance, from_attributes=True) + return TaskInstanceResponse.model_validate(task_instance) @task_instances_router.get( @@ -173,8 +173,7 @@ def get_mapped_task_instances( return TaskInstanceCollectionResponse( task_instances=[ - TaskInstanceResponse.model_validate(task_instance, from_attributes=True) - for task_instance in task_instances + TaskInstanceResponse.model_validate(task_instance) for task_instance in task_instances ], total_entries=total_entries, ) @@ -260,7 +259,7 @@ def get_mapped_task_instance( f"The Mapped Task Instance with dag_id: `{dag_id}`, run_id: `{dag_run_id}`, task_id: `{task_id}`, and map_index: `{map_index}` was not found", ) - return TaskInstanceResponse.model_validate(task_instance, from_attributes=True) + return TaskInstanceResponse.model_validate(task_instance) @task_instances_router.get( @@ -336,7 +335,7 @@ def get_task_instances( ) task_instances = session.scalars(task_instance_select) return TaskInstanceCollectionResponse( - task_instances=[TaskInstanceResponse.model_validate(t, from_attributes=True) for t in task_instances], + task_instances=task_instances, total_entries=total_entries, ) @@ -412,7 +411,7 @@ def get_task_instances_batch( task_instances = session.scalars(task_instance_select) return TaskInstanceCollectionResponse( - task_instances=[TaskInstanceResponse.model_validate(t, from_attributes=True) for t in task_instances], + task_instances=[TaskInstanceResponse.model_validate(t) for t in task_instances], total_entries=total_entries, ) @@ -449,7 +448,7 @@ def _query(orm_object: Base) -> TI | TIH | None: status.HTTP_404_NOT_FOUND, f"The Task Instance with dag_id: `{dag_id}`, run_id: `{dag_run_id}`, task_id: `{task_id}`, try_number: `{task_try_number}` and map_index: `{map_index}` was not found", ) - return TaskInstanceHistoryResponse.model_validate(result, from_attributes=True) + return TaskInstanceHistoryResponse.model_validate(result) @task_instances_router.get( diff --git a/airflow/api_fastapi/core_api/routes/public/tasks.py b/airflow/api_fastapi/core_api/routes/public/tasks.py index be1fdc7324d8..748a1eeebd1e 100644 --- a/airflow/api_fastapi/core_api/routes/public/tasks.py +++ b/airflow/api_fastapi/core_api/routes/public/tasks.py @@ -54,7 +54,7 @@ def get_tasks( raise HTTPException(status.HTTP_400_BAD_REQUEST, str(err)) return TaskCollectionResponse( tasks=[TaskResponse.model_validate(task, from_attributes=True) for task in tasks], - total_entries=(len(tasks)), + total_entries=len(tasks), ) @@ -76,4 +76,4 @@ def get_task(dag_id: str, task_id, request: Request) -> TaskResponse: task = dag.get_task(task_id=task_id) except TaskNotFound: raise HTTPException(status.HTTP_404_NOT_FOUND, f"Task with id {task_id} was not found") - return TaskResponse.model_validate(task, from_attributes=True) + return TaskResponse.model_validate(task) diff --git a/airflow/api_fastapi/core_api/routes/public/variables.py b/airflow/api_fastapi/core_api/routes/public/variables.py index 541dbcb8f107..06189f2d1c1c 100644 --- a/airflow/api_fastapi/core_api/routes/public/variables.py +++ b/airflow/api_fastapi/core_api/routes/public/variables.py @@ -68,7 +68,7 @@ def get_variable( status.HTTP_404_NOT_FOUND, f"The Variable with key: `{variable_key}` was not found" ) - return VariableResponse.model_validate(variable, from_attributes=True) + return VariableResponse.model_validate(variable) @variables_router.get( @@ -100,7 +100,7 @@ def get_variables( variables = session.scalars(variable_select) return VariableCollectionResponse( - variables=[VariableResponse.model_validate(variable, from_attributes=True) for variable in variables], + variables=variables, total_entries=total_entries, ) @@ -139,7 +139,7 @@ def patch_variable( data = patch_body.model_dump(exclude=non_update_fields, by_alias=True, exclude_none=True) for key, val in data.items(): setattr(variable, key, val) - return VariableResponse.model_validate(variable, from_attributes=True) + return VariableResponse.model_validate(variable) @variables_router.post( @@ -155,4 +155,4 @@ def post_variable( variable = session.scalar(select(Variable).where(Variable.key == post_body.key).limit(1)) - return VariableResponse.model_validate(variable, from_attributes=True) + return VariableResponse.model_validate(variable) diff --git a/airflow/api_fastapi/core_api/routes/public/xcom.py b/airflow/api_fastapi/core_api/routes/public/xcom.py index ef13c927e863..dff2933940c6 100644 --- a/airflow/api_fastapi/core_api/routes/public/xcom.py +++ b/airflow/api_fastapi/core_api/routes/public/xcom.py @@ -89,6 +89,6 @@ def get_xcom_entry( item = xcom_stub if stringify: - return XComResponseString.model_validate(item, from_attributes=True) + return XComResponseString.model_validate(item) - return XComResponseNative.model_validate(item, from_attributes=True) + return XComResponseNative.model_validate(item) diff --git a/airflow/api_fastapi/core_api/routes/ui/dags.py b/airflow/api_fastapi/core_api/routes/ui/dags.py index 96b8b0c1b109..017ef3c16570 100644 --- a/airflow/api_fastapi/core_api/routes/ui/dags.py +++ b/airflow/api_fastapi/core_api/routes/ui/dags.py @@ -124,9 +124,9 @@ def recent_dag_runs( for row in dags_with_recent_dag_runs: dag_run, dag, *_ = row dag_id = dag.dag_id - dag_run_response = DAGRunResponse.model_validate(dag_run, from_attributes=True) + dag_run_response = DAGRunResponse.model_validate(dag_run) if dag_id not in dag_runs_by_dag_id: - dag_response = DAGResponse.model_validate(dag, from_attributes=True) + dag_response = DAGResponse.model_validate(dag) dag_runs_by_dag_id[dag_id] = DAGWithLatestDagRunsResponse.model_validate( { **dag_response.dict(), diff --git a/airflow/api_fastapi/core_api/routes/ui/dashboard.py b/airflow/api_fastapi/core_api/routes/ui/dashboard.py index 9462d7ee2f7c..24682fa0c17c 100644 --- a/airflow/api_fastapi/core_api/routes/ui/dashboard.py +++ b/airflow/api_fastapi/core_api/routes/ui/dashboard.py @@ -97,4 +97,4 @@ def historical_metrics( }, } - return HistoricalMetricDataResponse.model_validate(historical_metrics_response, from_attributes=True) + return HistoricalMetricDataResponse.model_validate(historical_metrics_response) diff --git a/airflow/api_fastapi/execution_api/routes/connections.py b/airflow/api_fastapi/execution_api/routes/connections.py index 86f94f5ef3f8..ed72522ee8e2 100644 --- a/airflow/api_fastapi/execution_api/routes/connections.py +++ b/airflow/api_fastapi/execution_api/routes/connections.py @@ -66,7 +66,7 @@ def get_connection( "message": f"Connection with ID {connection_id} not found", }, ) - return ConnectionResponse.model_validate(connection, from_attributes=True) + return ConnectionResponse.model_validate(connection) def has_connection_access(connection_id: str, token: TIToken) -> bool: From 7561902d1ec327313cfe79403c6b0959b504a635 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Thu, 21 Nov 2024 13:33:14 -0800 Subject: [PATCH 3/6] get more of them --- airflow/api_fastapi/core_api/routes/public/assets.py | 4 ++-- airflow/api_fastapi/core_api/routes/public/dag_run.py | 2 +- airflow/api_fastapi/core_api/routes/public/dags.py | 2 +- .../api_fastapi/core_api/routes/public/task_instances.py | 6 ++---- airflow/api_fastapi/core_api/routes/public/tasks.py | 2 +- 5 files changed, 7 insertions(+), 9 deletions(-) diff --git a/airflow/api_fastapi/core_api/routes/public/assets.py b/airflow/api_fastapi/core_api/routes/public/assets.py index 6d3bbb8086a0..98e2217e5087 100644 --- a/airflow/api_fastapi/core_api/routes/public/assets.py +++ b/airflow/api_fastapi/core_api/routes/public/assets.py @@ -157,7 +157,7 @@ def get_asset_events( assets_events = session.scalars(assets_event_select) return AssetEventCollectionResponse( - asset_events=[AssetEventResponse.model_validate(asset) for asset in assets_events], + asset_events=assets_events, total_entries=total_entries, ) @@ -280,7 +280,7 @@ def get_dag_asset_queued_events( ] return QueuedEventCollectionResponse( - queued_events=[QueuedEventResponse.model_validate(queued_event) for queued_event in queued_events], + queued_events=queued_events, total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow/api_fastapi/core_api/routes/public/dag_run.py index 70e602ab684c..4a74d233542a 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_run.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_run.py @@ -293,6 +293,6 @@ def get_dag_runs( ) dag_runs = session.scalars(dag_run_select) return DAGRunCollectionResponse( - dag_runs=[DAGRunResponse.model_validate(dr) for dr in dag_runs], + dag_runs=dag_runs, total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/dags.py b/airflow/api_fastapi/core_api/routes/public/dags.py index 619c55b970ef..00081a057a9c 100644 --- a/airflow/api_fastapi/core_api/routes/public/dags.py +++ b/airflow/api_fastapi/core_api/routes/public/dags.py @@ -280,7 +280,7 @@ def patch_dags( ) return DAGCollectionResponse( - dags=[DAGResponse.model_validate(d) for d in dags], + dags=dags, total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow/api_fastapi/core_api/routes/public/task_instances.py index d87ea6a0bd48..836a61da7604 100644 --- a/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -172,9 +172,7 @@ def get_mapped_task_instances( task_instances = session.scalars(task_instance_select) return TaskInstanceCollectionResponse( - task_instances=[ - TaskInstanceResponse.model_validate(task_instance) for task_instance in task_instances - ], + task_instances=task_instances, total_entries=total_entries, ) @@ -411,7 +409,7 @@ def get_task_instances_batch( task_instances = session.scalars(task_instance_select) return TaskInstanceCollectionResponse( - task_instances=[TaskInstanceResponse.model_validate(t) for t in task_instances], + task_instances=task_instances, total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/tasks.py b/airflow/api_fastapi/core_api/routes/public/tasks.py index 748a1eeebd1e..fb8d1be8ed0f 100644 --- a/airflow/api_fastapi/core_api/routes/public/tasks.py +++ b/airflow/api_fastapi/core_api/routes/public/tasks.py @@ -53,7 +53,7 @@ def get_tasks( except AttributeError as err: raise HTTPException(status.HTTP_400_BAD_REQUEST, str(err)) return TaskCollectionResponse( - tasks=[TaskResponse.model_validate(task, from_attributes=True) for task in tasks], + tasks=[TaskResponse.model_validate(task) for task in tasks], total_entries=len(tasks), ) From b439d2e5de0d5a00188497cb6c1397eb80a93848 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Thu, 21 Nov 2024 14:03:15 -0800 Subject: [PATCH 4/6] fix test --- airflow/api_fastapi/execution_api/datamodels/connection.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/airflow/api_fastapi/execution_api/datamodels/connection.py b/airflow/api_fastapi/execution_api/datamodels/connection.py index f3c678952982..e2641417f566 100644 --- a/airflow/api_fastapi/execution_api/datamodels/connection.py +++ b/airflow/api_fastapi/execution_api/datamodels/connection.py @@ -17,7 +17,9 @@ from __future__ import annotations -from pydantic import BaseModel, Field +from pydantic import Field + +from airflow.api_fastapi.core_api.base import BaseModel class ConnectionResponse(BaseModel): From f835db51bbefeda8d162709962ca60ef5a790b23 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Wed, 20 Nov 2024 12:36:00 -0800 Subject: [PATCH 5/6] Update backfill `list` endpoint to be async --- airflow/api_fastapi/common/db/common.py | 58 ++++++++++++++++++- .../core_api/routes/public/backfills.py | 15 +++-- airflow/settings.py | 10 ++-- airflow/utils/db.py | 16 +++++ airflow/utils/session.py | 18 ++++++ tests/utils/test_session.py | 4 +- 6 files changed, 103 insertions(+), 18 deletions(-) diff --git a/airflow/api_fastapi/common/db/common.py b/airflow/api_fastapi/common/db/common.py index 17da1eafacc9..578a8b3a9a8b 100644 --- a/airflow/api_fastapi/common/db/common.py +++ b/airflow/api_fastapi/common/db/common.py @@ -24,8 +24,10 @@ from typing import TYPE_CHECKING, Literal, Sequence, overload -from airflow.utils.db import get_query_count -from airflow.utils.session import NEW_SESSION, create_session, provide_session +from sqlalchemy.ext.asyncio import AsyncSession + +from airflow.utils.db import get_query_count, get_query_count_async +from airflow.utils.session import NEW_SESSION, create_session, create_session_async, provide_session if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -53,7 +55,9 @@ def your_route(session: Annotated[Session, Depends(get_session)]): def apply_filters_to_select( - *, base_select: Select, filters: Sequence[BaseParam | None] | None = None + *, + base_select: Select, + filters: Sequence[BaseParam | None] | None = None, ) -> Select: if filters is None: return base_select @@ -65,6 +69,54 @@ def apply_filters_to_select( return base_select +async def get_async_session() -> AsyncSession: + """ + Dependency for providing a session. + + Example usage: + + .. code:: python + + @router.get("/your_path") + def your_route(session: Annotated[AsyncSession, Depends(get_async_session)]): + pass + """ + async with create_session_async() as session: + yield session + + +async def paginated_select_async( + *, + base_select: Select, + filters: Sequence[BaseParam | None] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, + session: AsyncSession, + return_total_entries: bool = True, +) -> tuple[Select, int | None]: + base_select = apply_filters_to_select( + base_select=base_select, + filters=filters, + ) + + total_entries = None + if return_total_entries: + total_entries = await get_query_count_async(base_select, session=session) + + # TODO: Re-enable when permissions are handled. Readable / writable entities, + # for instance: + # readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user) + # dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags)) + + base_select = apply_filters_to_select( + base_select=base_select, + filters=[order_by, offset, limit], + ) + + return base_select, total_entries + + @overload def paginated_select( *, diff --git a/airflow/api_fastapi/core_api/routes/public/backfills.py b/airflow/api_fastapi/core_api/routes/public/backfills.py index 0ce6a90839cc..3ec54a8ba0f7 100644 --- a/airflow/api_fastapi/core_api/routes/public/backfills.py +++ b/airflow/api_fastapi/core_api/routes/public/backfills.py @@ -20,9 +20,10 @@ from fastapi import Depends, HTTPException, status from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from airflow.api_fastapi.common.db.common import get_session, paginated_select +from airflow.api_fastapi.common.db.common import get_async_session, get_session, paginated_select_async from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset, SortParam from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.core_api.datamodels.backfills import ( @@ -49,7 +50,7 @@ @backfills_router.get( path="", ) -def list_backfills( +async def list_backfills( dag_id: str, limit: QueryLimit, offset: QueryOffset, @@ -57,18 +58,16 @@ def list_backfills( SortParam, Depends(SortParam(["id"], Backfill).dynamic_depends()), ], - session: Annotated[Session, Depends(get_session)], + session: Annotated[AsyncSession, Depends(get_async_session)], ) -> BackfillCollectionResponse: - select_stmt, total_entries = paginated_select( - select=select(Backfill).where(Backfill.dag_id == dag_id), + select_stmt, total_entries = await paginated_select_async( + base_select=select(Backfill).where(Backfill.dag_id == dag_id), order_by=order_by, offset=offset, limit=limit, session=session, ) - - backfills = session.scalars(select_stmt) - + backfills = await session.scalars(select_stmt) return BackfillCollectionResponse( backfills=backfills, total_entries=total_entries, diff --git a/airflow/settings.py b/airflow/settings.py index 5b458efcba47..76b3e948964f 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -31,7 +31,7 @@ import pluggy from packaging.version import Version from sqlalchemy import create_engine, exc, text -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession as SAAsyncSession, create_async_engine from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.pool import NullPool @@ -111,7 +111,7 @@ # this is achieved by the Session factory above. NonScopedSession: Callable[..., SASession] async_engine: AsyncEngine -create_async_session: Callable[..., AsyncSession] +AsyncSession: Callable[..., SAAsyncSession] # The JSON library to use for DAG Serialization and De-Serialization json = json @@ -469,7 +469,7 @@ def configure_orm(disable_connection_pool=False, pool_class=None): global Session global engine global async_engine - global create_async_session + global AsyncSession global NonScopedSession if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true": @@ -498,11 +498,11 @@ def configure_orm(disable_connection_pool=False, pool_class=None): engine = create_engine(SQL_ALCHEMY_CONN, connect_args=connect_args, **engine_args, future=True) async_engine = create_async_engine(SQL_ALCHEMY_CONN_ASYNC, future=True) - create_async_session = sessionmaker( + AsyncSession = sessionmaker( bind=async_engine, autocommit=False, autoflush=False, - class_=AsyncSession, + class_=SAAsyncSession, expire_on_commit=False, ) mask_secret(engine.url.password) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index d8939a117317..00d20bbbca9f 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -70,6 +70,7 @@ from alembic.runtime.environment import EnvironmentContext from alembic.script import ScriptDirectory from sqlalchemy.engine import Row + from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from sqlalchemy.sql.elements import ClauseElement, TextClause from sqlalchemy.sql.selectable import Select @@ -1447,6 +1448,21 @@ def get_query_count(query_stmt: Select, *, session: Session) -> int: return session.scalar(count_stmt) +async def get_query_count_async(query_stmt: Select, *, session: AsyncSession) -> int: + """ + Get count of a query. + + A SELECT COUNT() FROM is issued against the subquery built from the + given statement. The ORDER BY clause is stripped from the statement + since it's unnecessary for COUNT, and can impact query planning and + degrade performance. + + :meta private: + """ + count_stmt = select(func.count()).select_from(query_stmt.order_by(None).subquery()) + return await session.scalar(count_stmt) + + def check_query_exists(query_stmt: Select, *, session: Session) -> bool: """ Check whether there is at least one row matching a query. diff --git a/airflow/utils/session.py b/airflow/utils/session.py index a63d3f3f937a..49383cdf4a8b 100644 --- a/airflow/utils/session.py +++ b/airflow/utils/session.py @@ -65,6 +65,24 @@ def create_session(scoped: bool = True) -> Generator[SASession, None, None]: session.close() +@contextlib.asynccontextmanager +async def create_session_async(): + """ + Context manager to create async session. + + :meta private: + """ + from airflow.settings import AsyncSession + + async with AsyncSession() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + + PS = ParamSpec("PS") RT = TypeVar("RT") diff --git a/tests/utils/test_session.py b/tests/utils/test_session.py index 02cba9e070dc..8d26a25c626a 100644 --- a/tests/utils/test_session.py +++ b/tests/utils/test_session.py @@ -58,9 +58,9 @@ def test_provide_session_with_kwargs(self): @pytest.mark.asyncio async def test_async_session(self): - from airflow.settings import create_async_session + from airflow.settings import AsyncSession - session = create_async_session() + session = AsyncSession() session.add(Log(event="hihi1234")) await session.commit() my_special_log_event = await session.scalar(select(Log).where(Log.event == "hihi1234").limit(1)) From 2939aa3d14e82123eeba3f3b39a37591f3a515e1 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Thu, 21 Nov 2024 14:45:38 -0800 Subject: [PATCH 6/6] fix name --- airflow/api_fastapi/common/db/common.py | 40 +++++++++++++++---- .../core_api/routes/public/backfills.py | 2 +- airflow/utils/db.py | 4 +- 3 files changed, 36 insertions(+), 10 deletions(-) diff --git a/airflow/api_fastapi/common/db/common.py b/airflow/api_fastapi/common/db/common.py index 578a8b3a9a8b..2d7da4bff737 100644 --- a/airflow/api_fastapi/common/db/common.py +++ b/airflow/api_fastapi/common/db/common.py @@ -85,9 +85,35 @@ def your_route(session: Annotated[AsyncSession, Depends(get_async_session)]): yield session +@overload async def paginated_select_async( *, - base_select: Select, + query: Select, + filters: Sequence[BaseParam] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, + session: AsyncSession, + return_total_entries: Literal[True] = True, +) -> tuple[Select, int]: ... + + +@overload +async def paginated_select_async( + *, + query: Select, + filters: Sequence[BaseParam] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, + session: AsyncSession, + return_total_entries: Literal[False], +) -> tuple[Select, None]: ... + + +async def paginated_select_async( + *, + query: Select, filters: Sequence[BaseParam | None] | None = None, order_by: BaseParam | None = None, offset: BaseParam | None = None, @@ -95,26 +121,26 @@ async def paginated_select_async( session: AsyncSession, return_total_entries: bool = True, ) -> tuple[Select, int | None]: - base_select = apply_filters_to_select( - base_select=base_select, + query = apply_filters_to_select( + base_select=query, filters=filters, ) total_entries = None if return_total_entries: - total_entries = await get_query_count_async(base_select, session=session) + total_entries = await get_query_count_async(query, session=session) # TODO: Re-enable when permissions are handled. Readable / writable entities, # for instance: # readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user) # dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags)) - base_select = apply_filters_to_select( - base_select=base_select, + query = apply_filters_to_select( + base_select=query, filters=[order_by, offset, limit], ) - return base_select, total_entries + return query, total_entries @overload diff --git a/airflow/api_fastapi/core_api/routes/public/backfills.py b/airflow/api_fastapi/core_api/routes/public/backfills.py index 3ec54a8ba0f7..2977685a6950 100644 --- a/airflow/api_fastapi/core_api/routes/public/backfills.py +++ b/airflow/api_fastapi/core_api/routes/public/backfills.py @@ -61,7 +61,7 @@ async def list_backfills( session: Annotated[AsyncSession, Depends(get_async_session)], ) -> BackfillCollectionResponse: select_stmt, total_entries = await paginated_select_async( - base_select=select(Backfill).where(Backfill.dag_id == dag_id), + query=select(Backfill).where(Backfill.dag_id == dag_id), order_by=order_by, offset=offset, limit=limit, diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 00d20bbbca9f..c899ebf615d0 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -1448,7 +1448,7 @@ def get_query_count(query_stmt: Select, *, session: Session) -> int: return session.scalar(count_stmt) -async def get_query_count_async(query_stmt: Select, *, session: AsyncSession) -> int: +async def get_query_count_async(query: Select, *, session: AsyncSession) -> int: """ Get count of a query. @@ -1459,7 +1459,7 @@ async def get_query_count_async(query_stmt: Select, *, session: AsyncSession) -> :meta private: """ - count_stmt = select(func.count()).select_from(query_stmt.order_by(None).subquery()) + count_stmt = select(func.count()).select_from(query.order_by(None).subquery()) return await session.scalar(count_stmt)