Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ repos:
^airflow-ctl.*\.py$|
^airflow-core/src/airflow/models/.*\.py$|
^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$|
^airflow-core/tests/unit/models/test_serialized_dag.py$|
^airflow-core/tests/unit/utils/test_db_cleanup.py$|
^dev/airflow_perf/scheduler_dag_execution_timing.py$|
^providers/openlineage/.*\.py$|
Expand Down
45 changes: 26 additions & 19 deletions airflow-core/tests/unit/models/test_serialized_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import pendulum
import pytest
from sqlalchemy import func, select, update
from sqlalchemy import delete, func, select, update

import airflow.example_dags as example_dags_module
from airflow.dag_processing.dagbag import DagBag
Expand Down Expand Up @@ -59,7 +59,12 @@ def make_example_dags(module):
from airflow.utils.session import create_session

with create_session() as session:
if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0:
if (
session.scalar(
select(func.count()).select_from(DagBundleModel).where(DagBundleModel.name == "testing")
)
== 0
):
testing = DagBundleModel(name="testing")
session.add(testing)

Expand Down Expand Up @@ -101,7 +106,7 @@ def test_write_dag(self, testing_dag_bundle):
with create_session() as session:
for dag in example_dags.values():
assert SDM.has_dag(dag.dag_id)
result = session.query(SDM).filter(SDM.dag_id == dag.dag_id).one()
# Use scalars(...).one() rather than Session.scalar(): scalar() returns None when
# nothing matches and silently takes the first row when several match, whereas
# .one() raises NoResultFound / MultipleResultsFound, so the test still enforces
# that exactly one serialized DAG row exists for this dag_id.
result = session.scalars(select(SDM).where(SDM.dag_id == dag.dag_id)).one()

assert result.dag_version.dag_code.fileloc == dag.fileloc
# Verifies JSON schema.
Expand All @@ -118,7 +123,7 @@ def my_callable():
with dag_maker("dag1"):
PythonOperator(task_id="task1", python_callable=lambda x: None)
dag_maker.create_dagrun(run_id="test2", logical_date=pendulum.datetime(2025, 1, 1))
assert len(session.query(DagVersion).all()) == 2
assert len(session.scalars(select(DagVersion)).all()) == 2

with dag_maker("dag2"):

Expand All @@ -136,7 +141,7 @@ def my_callable2():
pass

my_callable2()
assert len(session.query(DagVersion).all()) == 4
assert len(session.scalars(select(DagVersion)).all()) == 4

def test_serialized_dag_is_updated_if_dag_is_changed(self, testing_dag_bundle):
"""Test Serialized DAG is updated if DAG is changed"""
Expand Down Expand Up @@ -212,7 +217,7 @@ def test_read_all_dags_only_picks_the_latest_serdags(self, session):
dag.doc_md = "new doc string"
SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name="testing")
serialized_dags2 = SDM.read_all_dags()
sdags = session.query(SDM).all()
sdags = session.scalars(select(SDM)).all()
# assert only the latest SDM is returned
assert len(sdags) != len(serialized_dags2)

Expand Down Expand Up @@ -334,7 +339,7 @@ def test_get_latest_serdag_versions(self, dag_maker, session):
def test_new_dag_versions_are_not_created_if_no_dagruns(self, dag_maker, session):
with dag_maker("dag1") as dag:
PythonOperator(task_id="task1", python_callable=lambda: None)
assert session.query(SDM).count() == 1
assert session.scalar(select(func.count()).select_from(SDM)) == 1
sdm1 = SDM.get(dag.dag_id, session=session)
dag_hash = sdm1.dag_hash
created_at = sdm1.created_at
Expand All @@ -347,21 +352,21 @@ def test_new_dag_versions_are_not_created_if_no_dagruns(self, dag_maker, session
assert sdm2.dag_hash != dag_hash # first recorded serdag
assert sdm2.created_at == created_at
assert sdm2.last_updated != last_updated
assert session.query(DagVersion).count() == 1
assert session.query(SDM).count() == 1
assert session.scalar(select(func.count()).select_from(DagVersion)) == 1
assert session.scalar(select(func.count()).select_from(SDM)) == 1

def test_new_dag_versions_are_created_if_there_is_a_dagrun(self, dag_maker, session):
with dag_maker("dag1") as dag:
PythonOperator(task_id="task1", python_callable=lambda: None)
dag_maker.create_dagrun(run_id="test3", logical_date=pendulum.datetime(2025, 1, 2))
assert session.query(SDM).count() == 1
assert session.query(DagVersion).count() == 1
assert session.scalar(select(func.count()).select_from(SDM)) == 1
assert session.scalar(select(func.count()).select_from(DagVersion)) == 1
# new task
PythonOperator(task_id="task2", python_callable=lambda: None, dag=dag)
SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name="dag_maker")

assert session.query(DagVersion).count() == 2
assert session.query(SDM).count() == 2
assert session.scalar(select(func.count()).select_from(DagVersion)) == 2
assert session.scalar(select(func.count()).select_from(SDM)) == 2

def test_example_dag_sorting_serialised_dag(self, session):
"""
Expand Down Expand Up @@ -517,14 +522,14 @@ def test_new_dag_version_created_when_bundle_name_changes_and_hash_unchanged(sel
# Create TIs
dag_maker.create_dagrun(run_id="test_run")

assert session.query(DagVersion).count() == 1
assert session.scalar(select(func.count()).select_from(DagVersion)) == 1

# Write the same DAG (no changes, so hash is the same) with a new bundle_name
new_bundle = "bundleB"
SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name=new_bundle)

# There should now be two versions of the DAG
assert session.query(DagVersion).count() == 2
assert session.scalar(select(func.count()).select_from(DagVersion)) == 2

def test_hash_method_removes_fileloc_and_remains_consistent(self):
"""Test that the hash method removes fileloc before hashing."""
Expand Down Expand Up @@ -632,7 +637,7 @@ def test_dynamic_dag_update_preserves_null_check(self, dag_maker, session):
assert dag_version is not None

# Manually delete SerializedDagModel (simulates edge case)
session.query(SDM).filter(SDM.dag_id == "test_missing_serdag").delete()
session.execute(delete(SDM).where(SDM.dag_id == "test_missing_serdag"))
session.commit()

# Verify no SerializedDagModel exists
Expand Down Expand Up @@ -709,7 +714,9 @@ def test_write_dag_atomicity_on_dagcode_failure(self, dag_maker, session):
EmptyOperator(task_id="task1")

dag = dag_maker.dag
initial_version_count = session.query(DagVersion).filter(DagVersion.dag_id == dag.dag_id).count()
initial_version_count = session.scalar(
select(func.count()).select_from(DagVersion).where(DagVersion.dag_id == dag.dag_id)
)
assert initial_version_count == 1, "Should have one DagVersion after initial write"
dag_maker.create_dagrun() # ensure the second dag version is created

Expand All @@ -732,8 +739,8 @@ def test_write_dag_atomicity_on_dagcode_failure(self, dag_maker, session):
# Verify that no new DagVersion was committed
# Use a fresh session to ensure we're reading from committed data
with create_session() as fresh_session:
final_version_count = (
fresh_session.query(DagVersion).filter(DagVersion.dag_id == dag.dag_id).count()
final_version_count = fresh_session.scalar(
select(func.count()).select_from(DagVersion).where(DagVersion.dag_id == dag.dag_id)
)
assert final_version_count == initial_version_count, (
"DagVersion should not be committed when DagCode.write_code fails"
Expand Down