Skip to content

Commit

Permalink
fix(datasets/managers): fix error handling fileloc when dataset alias resolved into new datasets (#42733)
Browse files Browse the repository at this point in the history

* fix(datasets/managers): fix error handling of fileloc when a dataset alias is resolved into new datasets

* test(datasets/manager): add test case test_register_dataset_change_with_alias

* refactor(datasets/manager): simplify for loop

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>

---------

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
  • Loading branch information
Lee-W and uranusjr authored Oct 8, 2024
1 parent b9069e7 commit 35264c1
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 4 deletions.
5 changes: 3 additions & 2 deletions airflow/datasets/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,14 +226,15 @@ def _send_dag_priority_parsing_request_if_needed(fileloc: str) -> str | None:
return None
return req.fileloc

(_send_dag_priority_parsing_request_if_needed(fileloc) for fileloc in file_locs)
for fileloc in file_locs:
_send_dag_priority_parsing_request_if_needed(fileloc)

@classmethod
def _postgres_send_dag_priority_parsing_request(cls, file_locs: Iterable[str], session: Session) -> None:
    """Queue a priority parsing request for every fileloc, PostgreSQL variant.

    Uses the PostgreSQL dialect's ``INSERT ... ON CONFLICT DO NOTHING`` so that
    filelocs which already have a pending request are skipped silently instead
    of raising a unique-constraint error.

    :param file_locs: filelocs of the DAG files to re-parse with priority.
    :param session: active SQLAlchemy session used to issue the insert.
    """
    from sqlalchemy.dialects.postgresql import insert

    stmt = insert(DagPriorityParsingRequest).on_conflict_do_nothing()
    # executemany form: one parameter dict per fileloc. A dict comprehension
    # ({"fileloc": f for f in file_locs}) would collapse everything into a
    # single row, inserting only one request.
    session.execute(stmt, [{"fileloc": fileloc} for fileloc in file_locs])


def resolve_dataset_manager() -> DatasetManager:
Expand Down
55 changes: 53 additions & 2 deletions tests/datasets/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,19 @@
import pytest
from sqlalchemy import delete

from airflow.datasets import Dataset
from airflow.datasets import Dataset, DatasetAlias
from airflow.datasets.manager import DatasetManager
from airflow.listeners.listener import get_listener_manager
from airflow.models.dag import DagModel
from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel
from airflow.models.dagbag import DagPriorityParsingRequest
from airflow.models.dataset import (
DagScheduleDatasetAliasReference,
DagScheduleDatasetReference,
DatasetAliasModel,
DatasetDagRunQueue,
DatasetEvent,
DatasetModel,
)
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from tests.listeners import dataset_listener

Expand All @@ -38,6 +46,15 @@
pytest.importorskip("pydantic", minversion="2.0.0")


@pytest.fixture
def clear_datasets():
    """Wipe the dataset tables before the test runs and again afterwards."""
    from tests.test_utils.db import clear_db_datasets as _wipe

    _wipe()  # start from a clean slate
    yield
    _wipe()  # leave nothing behind for later tests


@pytest.fixture
def mock_task_instance():
return TaskInstancePydantic(
Expand Down Expand Up @@ -127,6 +144,40 @@ def test_register_dataset_change(self, session, dag_maker, mock_task_instance):
assert session.query(DatasetEvent).filter_by(dataset_id=dsm.id).count() == 1
assert session.query(DatasetDagRunQueue).count() == 2

@pytest.mark.usefixtures("clear_datasets")
def test_register_dataset_change_with_alias(self, session, dag_maker, mock_task_instance):
    """Resolving a DatasetAlias into a new dataset must notify alias consumers.

    Registers a dataset change through an alias and verifies that one dataset
    event is recorded, both alias-consuming DAGs get a run-queue entry, and a
    priority parsing request is created for each consumer DAG's fileloc.
    """
    consumer_dag_1 = DagModel(dag_id="consumer_1", is_active=True, fileloc="dag1.py")
    consumer_dag_2 = DagModel(dag_id="consumer_2", is_active=True, fileloc="dag2.py")
    session.add_all([consumer_dag_1, consumer_dag_2])

    dsm = DatasetModel(uri="test_dataset_uri")
    session.add(dsm)

    dsam = DatasetAliasModel(name="test_dataset_name")
    session.add(dsam)
    # Both DAGs consume the alias (not the dataset directly), so the change
    # must be propagated through the alias-resolution path.
    dsam.consuming_dags = [
        DagScheduleDatasetAliasReference(dag_id=dag.dag_id) for dag in (consumer_dag_1, consumer_dag_2)
    ]
    session.execute(delete(DatasetDagRunQueue))  # start from an empty queue
    session.flush()

    dataset = Dataset(uri="test_dataset_uri")
    dataset_alias = DatasetAlias(name="test_dataset_name")
    dataset_manager = DatasetManager()
    dataset_manager.register_dataset_change(
        task_instance=mock_task_instance,
        dataset=dataset,
        aliases=[dataset_alias],
        source_alias_names=["test_dataset_name"],
        session=session,
    )
    session.flush()

    # One event for the dataset; one run-queue entry and one parsing request
    # per alias-consuming DAG.
    assert session.query(DatasetEvent).filter_by(dataset_id=dsm.id).count() == 1
    assert session.query(DatasetDagRunQueue).count() == 2
    assert session.query(DagPriorityParsingRequest).count() == 2

def test_register_dataset_change_no_downstreams(self, session, mock_task_instance):
dsem = DatasetManager()

Expand Down

0 comments on commit 35264c1

Please sign in to comment.