Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ repos:
(?x)
^airflow-ctl.*\.py$|
^airflow-core/src/airflow/models/.*\.py$|
^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$|
^task_sdk.*\.py$
pass_filenames: true
- id: update-supported-versions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import pytest
import time_machine
from sqlalchemy import delete, func, select

from airflow._shared.timezones import timezone
from airflow.models import DagModel
Expand Down Expand Up @@ -224,8 +225,8 @@ def _create_dag_run(session, num: int = 2):

def _create_asset_dag_run(session, num: int = 2):
for i in range(1, 1 + num):
dag_run = session.query(DagRun).filter_by(run_id=f"source_run_id_{i}").first()
asset_event = session.query(AssetEvent).filter_by(id=i).first()
dag_run = session.scalar(select(DagRun).where(DagRun.run_id == f"source_run_id_{i}"))
asset_event = session.scalar(select(AssetEvent).where(AssetEvent.id == i))
if dag_run and asset_event:
dag_run.consumed_asset_events.append(asset_event)
session.commit()
Expand Down Expand Up @@ -300,8 +301,8 @@ def test_should_respond_200(self, test_client, session):
session.add(AssetModel("inactive", "inactive"))
session.commit()

assert len(session.query(AssetModel).all()) == 3
assert len(session.query(AssetActive).all()) == 2
assert len(session.scalars(select(AssetModel)).all()) == 3
assert len(session.scalars(select(AssetActive)).all()) == 2

with assert_queries_count(7):
response = test_client.get("/assets")
Expand Down Expand Up @@ -417,8 +418,8 @@ def test_should_show_inactive(self, test_client, session):
)
session.commit()

assert len(session.query(AssetModel).all()) == 3
assert len(session.query(AssetActive).all()) == 2
assert len(session.scalars(select(AssetModel)).all()) == 3
assert len(session.scalars(select(AssetActive)).all()) == 2

response = test_client.get("/assets?only_active=0")
assert response.status_code == 200
Expand Down Expand Up @@ -567,7 +568,7 @@ def test_filter_assets_by_uri_pattern_works(self, test_client, params, expected_
def test_filter_assets_by_dag_ids_works(
self, test_client, dag_ids, expected_num, testing_dag_bundle, session
):
session.query(DagModel).delete()
session.execute(delete(DagModel))
session.commit()
bundle_name = "testing"

Expand Down Expand Up @@ -606,7 +607,7 @@ def test_filter_assets_by_dag_ids_works(
def test_filter_assets_by_dag_ids_and_uri_pattern_works(
self, test_client, dag_ids, uri_pattern, expected_num, testing_dag_bundle, session
):
session.query(DagModel).delete()
session.execute(delete(DagModel))
session.commit()
bundle_name = "testing"

Expand Down Expand Up @@ -697,7 +698,7 @@ def create_provided_asset_alias(self, asset_alias: AssetAliasModel, session):
class TestGetAssetAliases(TestAssetAliases):
def test_should_respond_200(self, test_client, session):
self.create_asset_aliases()
asset_aliases = session.query(AssetAliasModel).all()
asset_aliases = session.scalars(select(AssetAliasModel)).all()
assert len(asset_aliases) == 2

with assert_queries_count(2):
Expand Down Expand Up @@ -782,7 +783,7 @@ def test_should_respond_200(self, test_client, session):
self.create_assets_events(session)
self.create_dag_run(session)
self.create_asset_dag_run(session)
assets = session.query(AssetEvent).all()
assets = session.scalars(select(AssetEvent)).all()
session.commit()
assert len(assets) == 2

Expand Down Expand Up @@ -1039,7 +1040,7 @@ class TestGetAssetEndpoint(TestAssets):
@provide_session
def test_should_respond_200(self, test_client, session):
self.create_assets(num=1)
assert session.query(AssetModel).count() == 1
assert session.scalars(select(func.count(AssetModel.id))).one() == 1
tz_datetime_format = from_datetime_to_zulu_without_ms(DEFAULT_DATE)
with assert_queries_count(6):
response = test_client.get("/assets/1")
Expand Down Expand Up @@ -1134,7 +1135,7 @@ class TestGetAssetAliasEndpoint(TestAssetAliases):
@provide_session
def test_should_respond_200(self, test_client, session):
self.create_asset_aliases(num=1)
assert session.query(AssetAliasModel).count() == 1
assert session.scalars(select(func.count(AssetAliasModel.id))).one() == 1
with assert_queries_count(6):
response = test_client.get("/assets/aliases/1")
assert response.status_code == 200
Expand All @@ -1148,7 +1149,7 @@ def test_should_respond_404(self, test_client):

class TestQueuedEventEndpoint(TestAssets):
def _create_asset_dag_run_queues(self, dag_id, asset_id, session):
session.query(AssetDagRunQueue).delete()
session.execute(delete(AssetDagRunQueue))
session.flush()
adrq = AssetDagRunQueue(target_dag_id=dag_id, asset_id=asset_id)
session.add(adrq)
Expand Down Expand Up @@ -1209,15 +1210,15 @@ def test_should_respond_204(self, test_client, session, create_dummy_dag):
self.create_assets(session=session, num=1)
asset_id = 1
self._create_asset_dag_run_queues(dag_id, asset_id, session)
adrqs = session.query(AssetDagRunQueue).all()
adrqs = session.scalars(select(AssetDagRunQueue)).all()
assert len(adrqs) == 1

response = test_client.delete(
f"/dags/{dag_id}/assets/queuedEvents",
)

assert response.status_code == 204
adrqs = session.query(AssetDagRunQueue).all()
adrqs = session.scalars(select(AssetDagRunQueue)).all()
assert len(adrqs) == 0
check_last_log(session, dag_id=dag_id, event="delete_dag_asset_queued_events", logical_date=None)

Expand All @@ -1243,7 +1244,7 @@ def test_should_respond_404_valid_dag_no_adrq(self, test_client, session, create
dag, _ = create_dummy_dag()
dag_id = dag.dag_id
self.create_assets(session=session, num=1)
adrqs = session.query(AssetDagRunQueue).all()
adrqs = session.scalars(select(AssetDagRunQueue)).all()
assert len(adrqs) == 0

response = test_client.delete(
Expand Down Expand Up @@ -1493,15 +1494,15 @@ def test_delete_should_respond_204(self, test_client, session, create_dummy_dag)
(asset,) = self.create_assets(session=session, num=1)

self._create_asset_dag_run_queues(dag_id, asset.id, session)
adrq = session.query(AssetDagRunQueue).all()
adrq = session.scalars(select(AssetDagRunQueue)).all()
assert len(adrq) == 1

response = test_client.delete(
f"/dags/{dag_id}/assets/{asset.id}/queuedEvents",
)

assert response.status_code == 204
adrq = session.query(AssetDagRunQueue).all()
adrq = session.scalars(select(AssetDagRunQueue)).all()
assert len(adrq) == 0
check_last_log(session, dag_id=dag_id, event="delete_dag_asset_queued_event", logical_date=None)

Expand Down
Loading