diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 572189651b3e7..63d0b0c5c6095 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -424,6 +424,7 @@ repos: (?x) ^airflow-ctl.*\.py$| ^airflow-core/src/airflow/models/.*\.py$| + ^airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py$| ^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$| ^airflow-core/tests/unit/models/test_serialized_dag.py$| ^airflow-core/tests/unit/models/test_pool.py$| @@ -439,7 +440,10 @@ repos: ^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py$| ^airflow-core/tests/unit/cli/commands/test_task_command.py$| ^airflow-core/tests/unit/dag_processing/bundles/test_dag_bundle_manager.py$| + ^airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py$| ^airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py$| + ^airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py$| + ^airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py$| ^airflow-core/tests/unit/models/test_deadline.py$| ^airflow-core/tests/unit/models/test_renderedtifields.py$| ^airflow-core/tests/unit/models/test_timestamp.py$| diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index eb808fb6cbc4b..514516e68d35a 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -248,11 +248,7 @@ def ti_run( xcom_keys = list(session.scalars(xcom_query)) task_reschedule_count = ( - session.query( - func.count(TaskReschedule.id) # or any other primary key column - ) - .filter(TaskReschedule.ti_id == ti_id_str) - .scalar() + session.scalar(select(func.count(TaskReschedule.id)).where(TaskReschedule.ti_id == ti_id_str)) or 0 ) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py index 07ea2e8f302f9..91c9c314a0e6d 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py @@ -19,6 +19,7 @@ import pytest import time_machine +from sqlalchemy import select, update from airflow._shared.timezones import timezone from airflow.models import DagModel @@ -56,7 +57,7 @@ def test_trigger_dag_run(self, client, session, dag_maker): assert response.status_code == 204 - dag_run = session.query(DagRun).filter(DagRun.run_id == run_id).one() + dag_run = session.scalars(select(DagRun).where(DagRun.run_id == run_id)).one() assert dag_run.conf == {"key1": "value1"} assert dag_run.logical_date == logical_date @@ -81,7 +82,7 @@ def test_trigger_dag_run_import_error(self, client, session, dag_maker): with dag_maker(dag_id=dag_id, session=session, serialized=True): EmptyOperator(task_id="test_task") - session.query(DagModel).filter(DagModel.dag_id == dag_id).update({"has_import_errors": True}) + session.execute(update(DagModel).where(DagModel.dag_id == dag_id).values(has_import_errors=True)) session.commit() @@ -160,7 +161,7 @@ def test_dag_run_clear(self, client, session, dag_maker): assert response.status_code == 204 session.expire_all() - dag_run = session.query(DagRun).filter(DagRun.run_id == run_id).one() + dag_run = session.scalars(select(DagRun).where(DagRun.run_id == run_id)).one() assert dag_run.state == DagRunState.QUEUED def test_dag_run_import_error(self, client, session, dag_maker): @@ -172,7 +173,7 @@ def test_dag_run_import_error(self, client, session, dag_maker): with dag_maker(dag_id=dag_id, session=session, serialized=True): EmptyOperator(task_id="test_task") - session.query(DagModel).filter(DagModel.dag_id == dag_id).update({"has_import_errors": True}) + session.execute(update(DagModel).where(DagModel.dag_id == dag_id).values(has_import_errors=True)) session.commit() diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py index 5066454ede9c8..59b206441dea6 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py @@ -23,6 +23,7 @@ import pytest from fastapi import FastAPI, HTTPException, Request, status from fastapi.routing import Mount +from sqlalchemy import select from airflow.models.variable import Variable @@ -150,7 +151,7 @@ def test_should_create_variable(self, client, key, payload, session): assert response.status_code == 201, response.json() assert response.json()["message"] == "Variable successfully set" - var_from_db = session.query(Variable).where(Variable.key == key).first() + var_from_db = session.scalars(select(Variable).where(Variable.key == key)).first() assert var_from_db is not None assert var_from_db.key == key assert var_from_db.val == payload["value"] @@ -216,7 +217,7 @@ def test_overwriting_existing_variable(self, client, session, key): assert response.status_code == 201 assert response.json()["message"] == "Variable successfully set" # variable should have been updated to the new value - var_from_db = session.query(Variable).where(Variable.key == key).first() + var_from_db = session.scalars(select(Variable).where(Variable.key == key)).first() assert var_from_db is not None assert var_from_db.key == key assert var_from_db.val == payload["value"] @@ -253,25 +254,25 @@ def test_should_delete_variable(self, client, session, keys_to_create, key_to_de for i, key in enumerate(keys_to_create, 1): Variable.set(key=key, value=str(i)) - vars = session.query(Variable).all() + vars = session.scalars(select(Variable)).all() assert len(vars) == len(keys_to_create) response = client.delete(f"/execution/variables/{key_to_delete}") assert response.status_code == 204 - vars = session.query(Variable).all() + vars = session.scalars(select(Variable)).all() assert len(vars) == len(keys_to_create) - 1 def test_should_not_delete_variable(self, client, session): Variable.set(key="key", value="value") - vars = session.query(Variable).all() + vars = session.scalars(select(Variable)).all() assert len(vars) == 1 response = client.delete("/execution/variables/non_existent_key") assert response.status_code == 204 - vars = session.query(Variable).all() + vars = session.scalars(select(Variable)).all() assert len(vars) == 1 diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py index a4be2c544506c..f805971bf5207 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py @@ -24,6 +24,7 @@ import httpx import pytest from fastapi import FastAPI, HTTPException, Path, Request, status +from sqlalchemy import delete, select from airflow._shared.timezones import timezone from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse @@ -41,8 +42,8 @@ def reset_db(): """Reset XCom entries.""" with create_session() as session: - session.query(DagRun).delete() - session.query(XComModel).delete() + session.execute(delete(DagRun)) + session.execute(delete(XComModel)) @pytest.fixture @@ -354,9 +355,17 @@ def test_xcom_set(self, client, create_task_instance, session, value, expected_v assert response.status_code == 201 assert response.json() == {"message": "XCom successfully set"} - xcom = session.query(XComModel).filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="xcom_1").first() + xcom = session.scalars( + select(XComModel).where( + XComModel.task_id == ti.task_id, + XComModel.dag_id == ti.dag_id, + XComModel.key == "xcom_1", + ) + ).first() assert xcom.value == expected_value - task_map = session.query(TaskMap).filter_by(task_id=ti.task_id, dag_id=ti.dag_id).one_or_none() + task_map = session.scalars( + select(TaskMap).where(TaskMap.task_id == ti.task_id, TaskMap.dag_id == ti.dag_id) + ).one_or_none() assert task_map is None, "Should not be mapped" @pytest.mark.parametrize( @@ -438,13 +447,18 @@ def test_xcom_set_mapped(self, client, create_task_instance, session): assert response.status_code == 201 assert response.json() == {"message": "XCom successfully set"} - xcom = ( - session.query(XComModel) - .filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="xcom_1", map_index=-1) - .first() - ) + xcom = session.scalars( + select(XComModel).where( + XComModel.task_id == ti.task_id, + XComModel.dag_id == ti.dag_id, + XComModel.key == "xcom_1", + XComModel.map_index == -1, + ) + ).first() assert xcom.value == "value1" - task_map = session.query(TaskMap).filter_by(task_id=ti.task_id, dag_id=ti.dag_id).one_or_none() + task_map = session.scalars( + select(TaskMap).where(TaskMap.task_id == ti.task_id, TaskMap.dag_id == ti.dag_id) + ).one_or_none() assert task_map is not None, "Should be mapped" assert task_map.dag_id == "dag" assert task_map.run_id == "test" @@ -484,7 +498,9 @@ def test_xcom_set_downstream_of_mapped(self, client, create_task_instance, sessi ) response.raise_for_status() - task_map = session.query(TaskMap).filter_by(task_id=ti.task_id, dag_id=ti.dag_id).one_or_none() + task_map = session.scalars( + select(TaskMap).where(TaskMap.task_id == ti.task_id, TaskMap.dag_id == ti.dag_id) + ).one_or_none() assert task_map.length == length @pytest.mark.usefixtures("access_denied") @@ -530,11 +546,13 @@ def test_xcom_roundtrip(self, client, create_task_instance, session, value, expe json=value, ) - xcom = ( - session.query(XComModel) - .filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="test_xcom_roundtrip") - .first() - ) + xcom = session.scalars( + select(XComModel).where( + XComModel.task_id == ti.task_id, + XComModel.dag_id == ti.dag_id, + XComModel.key == "test_xcom_roundtrip", + ) + ).first() assert xcom.value == expected_value response = client.get(f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/test_xcom_roundtrip") @@ -553,7 +571,7 @@ def test_xcom_delete_endpoint(self, client, create_task_instance, session): ti1.xcom_push(key="xcom_1", value='"value2"', session=session) session.commit() - xcoms = session.query(XComModel).filter_by(key="xcom_1").all() + xcoms = session.scalars(select(XComModel).where(XComModel.key == "xcom_1")).all() assert xcoms is not None assert len(xcoms) == 2 @@ -562,12 +580,20 @@ def test_xcom_delete_endpoint(self, client, create_task_instance, session): assert response.status_code == 200 assert response.json() == {"message": "XCom with key: xcom_1 successfully deleted."} - xcom_ti = ( - session.query(XComModel).filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="xcom_1").first() - ) + xcom_ti = session.scalars( + select(XComModel).where( + XComModel.task_id == ti.task_id, + XComModel.dag_id == ti.dag_id, + XComModel.key == "xcom_1", + ) + ).first() assert xcom_ti is None - xcom_ti = ( - session.query(XComModel).filter_by(task_id=ti1.task_id, dag_id=ti1.dag_id, key="xcom_1").first() - ) + xcom_ti = session.scalars( + select(XComModel).where( + XComModel.task_id == ti1.task_id, + XComModel.dag_id == ti1.dag_id, + XComModel.key == "xcom_1", + ) + ).first() assert xcom_ti is not None