Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ repos:
(?x)
^airflow-ctl.*\.py$|
^airflow-core/src/airflow/models/.*\.py$|
^airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py$|
^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$|
^airflow-core/tests/unit/models/test_serialized_dag.py$|
^airflow-core/tests/unit/models/test_pool.py$|
Expand All @@ -439,7 +440,10 @@ repos:
^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py$|
^airflow-core/tests/unit/cli/commands/test_task_command.py$|
^airflow-core/tests/unit/dag_processing/bundles/test_dag_bundle_manager.py$|
^airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py$|
^airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py$|
^airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py$|
^airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py$|
^airflow-core/tests/unit/models/test_deadline.py$|
^airflow-core/tests/unit/models/test_renderedtifields.py$|
^airflow-core/tests/unit/models/test_timestamp.py$|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,7 @@ def ti_run(

xcom_keys = list(session.scalars(xcom_query))
task_reschedule_count = (
session.query(
func.count(TaskReschedule.id) # or any other primary key column
)
.filter(TaskReschedule.ti_id == ti_id_str)
.scalar()
session.scalar(select(func.count(TaskReschedule.id)).where(TaskReschedule.ti_id == ti_id_str))
or 0
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import pytest
import time_machine
from sqlalchemy import select, update

from airflow._shared.timezones import timezone
from airflow.models import DagModel
Expand Down Expand Up @@ -56,7 +57,7 @@ def test_trigger_dag_run(self, client, session, dag_maker):

assert response.status_code == 204

dag_run = session.query(DagRun).filter(DagRun.run_id == run_id).one()
dag_run = session.scalars(select(DagRun).where(DagRun.run_id == run_id)).one()
assert dag_run.conf == {"key1": "value1"}
assert dag_run.logical_date == logical_date

Expand All @@ -81,7 +82,7 @@ def test_trigger_dag_run_import_error(self, client, session, dag_maker):
with dag_maker(dag_id=dag_id, session=session, serialized=True):
EmptyOperator(task_id="test_task")

session.query(DagModel).filter(DagModel.dag_id == dag_id).update({"has_import_errors": True})
session.execute(update(DagModel).where(DagModel.dag_id == dag_id).values(has_import_errors=True))

session.commit()

Expand Down Expand Up @@ -160,7 +161,7 @@ def test_dag_run_clear(self, client, session, dag_maker):
assert response.status_code == 204

session.expire_all()
dag_run = session.query(DagRun).filter(DagRun.run_id == run_id).one()
dag_run = session.scalars(select(DagRun).where(DagRun.run_id == run_id)).one()
assert dag_run.state == DagRunState.QUEUED

def test_dag_run_import_error(self, client, session, dag_maker):
Expand All @@ -172,7 +173,7 @@ def test_dag_run_import_error(self, client, session, dag_maker):
with dag_maker(dag_id=dag_id, session=session, serialized=True):
EmptyOperator(task_id="test_task")

session.query(DagModel).filter(DagModel.dag_id == dag_id).update({"has_import_errors": True})
session.execute(update(DagModel).where(DagModel.dag_id == dag_id).values(has_import_errors=True))

session.commit()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import pytest
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.routing import Mount
from sqlalchemy import select

from airflow.models.variable import Variable

Expand Down Expand Up @@ -150,7 +151,7 @@ def test_should_create_variable(self, client, key, payload, session):
assert response.status_code == 201, response.json()
assert response.json()["message"] == "Variable successfully set"

var_from_db = session.query(Variable).where(Variable.key == key).first()
var_from_db = session.scalars(select(Variable).where(Variable.key == key)).first()
assert var_from_db is not None
assert var_from_db.key == key
assert var_from_db.val == payload["value"]
Expand Down Expand Up @@ -216,7 +217,7 @@ def test_overwriting_existing_variable(self, client, session, key):
assert response.status_code == 201
assert response.json()["message"] == "Variable successfully set"
# variable should have been updated to the new value
var_from_db = session.query(Variable).where(Variable.key == key).first()
var_from_db = session.scalars(select(Variable).where(Variable.key == key)).first()
assert var_from_db is not None
assert var_from_db.key == key
assert var_from_db.val == payload["value"]
Expand Down Expand Up @@ -253,25 +254,25 @@ def test_should_delete_variable(self, client, session, keys_to_create, key_to_de
for i, key in enumerate(keys_to_create, 1):
Variable.set(key=key, value=str(i))

vars = session.query(Variable).all()
vars = session.scalars(select(Variable)).all()
assert len(vars) == len(keys_to_create)

response = client.delete(f"/execution/variables/{key_to_delete}")

assert response.status_code == 204

vars = session.query(Variable).all()
vars = session.scalars(select(Variable)).all()
assert len(vars) == len(keys_to_create) - 1

def test_should_not_delete_variable(self, client, session):
Variable.set(key="key", value="value")

vars = session.query(Variable).all()
vars = session.scalars(select(Variable)).all()
assert len(vars) == 1

response = client.delete("/execution/variables/non_existent_key")

assert response.status_code == 204

vars = session.query(Variable).all()
vars = session.scalars(select(Variable)).all()
assert len(vars) == 1
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import httpx
import pytest
from fastapi import FastAPI, HTTPException, Path, Request, status
from sqlalchemy import delete, select

from airflow._shared.timezones import timezone
from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse
Expand All @@ -41,8 +42,8 @@
def reset_db():
"""Reset XCom entries."""
with create_session() as session:
session.query(DagRun).delete()
session.query(XComModel).delete()
session.execute(delete(DagRun))
session.execute(delete(XComModel))


@pytest.fixture
Expand Down Expand Up @@ -354,9 +355,17 @@ def test_xcom_set(self, client, create_task_instance, session, value, expected_v
assert response.status_code == 201
assert response.json() == {"message": "XCom successfully set"}

xcom = session.query(XComModel).filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="xcom_1").first()
xcom = session.scalars(
select(XComModel).where(
XComModel.task_id == ti.task_id,
XComModel.dag_id == ti.dag_id,
XComModel.key == "xcom_1",
)
).first()
assert xcom.value == expected_value
task_map = session.query(TaskMap).filter_by(task_id=ti.task_id, dag_id=ti.dag_id).one_or_none()
task_map = session.scalars(
select(TaskMap).where(TaskMap.task_id == ti.task_id, TaskMap.dag_id == ti.dag_id)
).one_or_none()
assert task_map is None, "Should not be mapped"

@pytest.mark.parametrize(
Expand Down Expand Up @@ -438,13 +447,18 @@ def test_xcom_set_mapped(self, client, create_task_instance, session):
assert response.status_code == 201
assert response.json() == {"message": "XCom successfully set"}

xcom = (
session.query(XComModel)
.filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="xcom_1", map_index=-1)
.first()
)
xcom = session.scalars(
select(XComModel).where(
XComModel.task_id == ti.task_id,
XComModel.dag_id == ti.dag_id,
XComModel.key == "xcom_1",
XComModel.map_index == -1,
)
).first()
assert xcom.value == "value1"
task_map = session.query(TaskMap).filter_by(task_id=ti.task_id, dag_id=ti.dag_id).one_or_none()
task_map = session.scalars(
select(TaskMap).where(TaskMap.task_id == ti.task_id, TaskMap.dag_id == ti.dag_id)
).one_or_none()
assert task_map is not None, "Should be mapped"
assert task_map.dag_id == "dag"
assert task_map.run_id == "test"
Expand Down Expand Up @@ -484,7 +498,9 @@ def test_xcom_set_downstream_of_mapped(self, client, create_task_instance, sessi
)
response.raise_for_status()

task_map = session.query(TaskMap).filter_by(task_id=ti.task_id, dag_id=ti.dag_id).one_or_none()
task_map = session.scalars(
select(TaskMap).where(TaskMap.task_id == ti.task_id, TaskMap.dag_id == ti.dag_id)
).one_or_none()
assert task_map.length == length

@pytest.mark.usefixtures("access_denied")
Expand Down Expand Up @@ -530,11 +546,13 @@ def test_xcom_roundtrip(self, client, create_task_instance, session, value, expe
json=value,
)

xcom = (
session.query(XComModel)
.filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="test_xcom_roundtrip")
.first()
)
xcom = session.scalars(
select(XComModel).where(
XComModel.task_id == ti.task_id,
XComModel.dag_id == ti.dag_id,
XComModel.key == "test_xcom_roundtrip",
)
).first()
assert xcom.value == expected_value

response = client.get(f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/test_xcom_roundtrip")
Expand All @@ -553,7 +571,7 @@ def test_xcom_delete_endpoint(self, client, create_task_instance, session):
ti1.xcom_push(key="xcom_1", value='"value2"', session=session)
session.commit()

xcoms = session.query(XComModel).filter_by(key="xcom_1").all()
xcoms = session.scalars(select(XComModel).where(XComModel.key == "xcom_1")).all()
assert xcoms is not None
assert len(xcoms) == 2

Expand All @@ -562,12 +580,20 @@ def test_xcom_delete_endpoint(self, client, create_task_instance, session):
assert response.status_code == 200
assert response.json() == {"message": "XCom with key: xcom_1 successfully deleted."}

xcom_ti = (
session.query(XComModel).filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="xcom_1").first()
)
xcom_ti = session.scalars(
select(XComModel).where(
XComModel.task_id == ti.task_id,
XComModel.dag_id == ti.dag_id,
XComModel.key == "xcom_1",
)
).first()
assert xcom_ti is None

xcom_ti = (
session.query(XComModel).filter_by(task_id=ti1.task_id, dag_id=ti1.dag_id, key="xcom_1").first()
)
xcom_ti = session.scalars(
select(XComModel).where(
XComModel.task_id == ti1.task_id,
XComModel.dag_id == ti1.dag_id,
XComModel.key == "xcom_1",
)
).first()
assert xcom_ti is not None
Loading