diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/structure.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/structure.py
index a84f36ab010d3..738e4c6edf65f 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/structure.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/structure.py
@@ -122,7 +122,15 @@ def structure_data(
         elif (
             dependency.target == dependency.dependency_type or dependency.source == dag_id
         ) and exit_node_ref:
-            end_edges.append({"source_id": exit_node_ref["id"], "target_id": dependency.node_id})
+            end_edges.append(
+                {
+                    "source_id": exit_node_ref["id"],
+                    "target_id": dependency.node_id,
+                    "resolved_from_alias": dependency.source.replace("asset-alias:", "", 1)
+                    if dependency.source.startswith("asset-alias:")
+                    else None,
+                }
+            )
 
         # Add nodes
         nodes.append(
@@ -142,6 +150,6 @@ def structure_data(
 
     data["edges"] += start_edges + end_edges
 
-    bind_output_assets_to_tasks(data["edges"], serialized_dag)
+    bind_output_assets_to_tasks(data["edges"], serialized_dag, version_number, session)
 
     return StructureDataResponse(**data)
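The hunk above only annotates alias-backed end edges; nothing is rebound at this point. A standalone sketch of the resulting edge shape (the `asset-alias:` source, node IDs, and exit node are invented for illustration):

```python
# Invented inputs: what `dependency.source` / `dependency.node_id` might hold
# for an asset that was published through an alias.
source = "asset-alias:my-alias"
node_id = "asset:42"
exit_node_id = "task_2"  # stand-in for exit_node_ref["id"]

end_edge = {
    "source_id": exit_node_id,
    "target_id": node_id,
    # Strip the "asset-alias:" prefix once; plain asset sources keep None.
    "resolved_from_alias": source.replace("asset-alias:", "", 1)
    if source.startswith("asset-alias:")
    else None,
}
assert end_edge["resolved_from_alias"] == "my-alias"
```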
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/structure.py b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/structure.py
index 6f5f415d3fdb7..db3d1ba6deac4 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/structure.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/structure.py
@@ -23,6 +23,14 @@
 
 from __future__ import annotations
 
+from collections import defaultdict
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from airflow.models.asset import AssetAliasModel, AssetEvent
+from airflow.models.dag_version import DagVersion
+from airflow.models.dagrun import DagRun
 from airflow.models.serialized_dag import SerializedDagModel
 
 
@@ -116,30 +124,62 @@ def get_upstream_assets(
     return nodes, edges
 
 
-def bind_output_assets_to_tasks(edges: list[dict], serialized_dag: SerializedDagModel) -> None:
+def bind_output_assets_to_tasks(
+    edges: list[dict], serialized_dag: SerializedDagModel, version_number: int, session: Session
+) -> None:
     """
     Try to bind the downstream assets to the relevant task that produces them.
 
     This function will mutate the `edges` in place.
     """
+    # Bind normal assets present in the `task_outlet_asset_references`
     outlet_asset_references = serialized_dag.dag_model.task_outlet_asset_references
 
-    downstream_asset_related_edges = [edge for edge in edges if edge["target_id"].startswith("asset:")]
-
-    for edge in downstream_asset_related_edges:
-        asset_id = int(edge["target_id"].strip("asset:"))
-        try:
-            # Try to attach the outlet asset to the relevant task
-            outlet_asset_reference = next(
-                outlet_asset_reference
-                for outlet_asset_reference in outlet_asset_references
-                if outlet_asset_reference.asset_id == asset_id
-            )
-            edge["source_id"] = outlet_asset_reference.task_id
-            continue
-        except StopIteration:
-            # If no asset reference found, fallback to using the exit node reference
-            # This can happen because asset aliases are not yet handled, they do no populate
-            # the `outlet_asset_references` when resolved. Extra lookup is needed. Same for asset-name-ref and
-            # asset-uri-ref.
-            pass
+    downstream_asset_edges = [
+        edge
+        for edge in edges
+        if edge["target_id"].startswith("asset:") and not edge.get("resolved_from_alias")
+    ]
+
+    for edge in downstream_asset_edges:
+        # Try to attach the outlet assets to the relevant tasks
+        asset_id = int(edge["target_id"].replace("asset:", "", 1))
+        outlet_asset_reference = next(
+            outlet_asset_reference
+            for outlet_asset_reference in outlet_asset_references
+            if outlet_asset_reference.asset_id == asset_id
+        )
+        edge["source_id"] = outlet_asset_reference.task_id
+
+    # Bind assets resolved from aliases; they do not populate the `outlet_asset_references`
+    downstream_alias_resolved_edges = [
+        edge for edge in edges if edge["target_id"].startswith("asset:") and edge.get("resolved_from_alias")
+    ]
+
+    alias_names = {edge["resolved_from_alias"] for edge in downstream_alias_resolved_edges}
+
+    result = session.scalars(
+        select(AssetEvent)
+        .join(AssetEvent.source_aliases)
+        .join(AssetEvent.source_dag_run)
+        # This is a simplification; strictly we should check that `version_number` is in `DagRun.dag_versions`.
+        .join(DagRun.created_dag_version)
+        .where(AssetEvent.source_aliases.any(AssetAliasModel.name.in_(alias_names)))
+        .where(AssetEvent.source_dag_run.has(DagRun.dag_id == serialized_dag.dag_model.dag_id))
+        .where(DagVersion.version_number == version_number)
+    ).unique()
+
+    asset_id_to_task_ids = defaultdict(set)
+    for asset_event in result:
+        asset_id_to_task_ids[asset_event.asset_id].add(asset_event.source_task_id)
+
+    for edge in downstream_alias_resolved_edges:
+        asset_id = int(edge["target_id"].replace("asset:", "", 1))
+        task_ids = asset_id_to_task_ids.get(asset_id, set())
+
+        for index, task_id in enumerate(task_ids):
+            if index == 0:
+                edge["source_id"] = task_id
+                continue
+            edge_copy = {**edge, "source_id": task_id}
+            edges.append(edge_copy)
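A minimal, self-contained sketch of the fan-out at the end of `bind_output_assets_to_tasks`, with fake edge and task IDs (the real code derives `asset_id_to_task_ids` from the `AssetEvent` query above and iterates the prefiltered alias edges rather than a snapshot): when several tasks emitted events for the same alias-resolved asset, the first producer rebinds the existing edge and every further producer gets a copy.

```python
edges = [{"source_id": "exit_node", "target_id": "asset:7", "resolved_from_alias": "a"}]
asset_id_to_task_ids = {7: {"task_a", "task_b"}}

for edge in list(edges):  # snapshot, since the loop appends to `edges`
    asset_id = int(edge["target_id"].replace("asset:", "", 1))
    for index, task_id in enumerate(sorted(asset_id_to_task_ids.get(asset_id, set()))):
        if index == 0:
            edge["source_id"] = task_id  # first producer rebinds the edge in place
        else:
            edges.append({**edge, "source_id": task_id})  # extra producers get copies

# One edge per producing task:
assert sorted(e["source_id"] for e in edges) == ["task_a", "task_b"]
```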
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_structure.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_structure.py
index 2c6425db1b395..1587797256a7a 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_structure.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_structure.py
@@ -25,18 +25,21 @@
 from sqlalchemy.orm import Session
 
 from airflow.models import DagBag
-from airflow.models.asset import AssetModel
+from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel
 from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator
 from airflow.providers.standard.sensors.external_task import ExternalTaskSensor
+from airflow.sdk import Metadata, task
 from airflow.sdk.definitions.asset import Asset, AssetAlias, Dataset
+from airflow.utils import timezone
 
-from tests_common.test_utils.db import clear_db_runs
+from tests_common.test_utils.db import clear_db_assets, clear_db_runs
 
 pytestmark = pytest.mark.db_test
 
 DAG_ID = "dag_with_multiple_versions"
 DAG_ID_EXTERNAL_TRIGGER = "external_trigger"
+DAG_ID_RESOLVED_ASSET_ALIAS = "dag_with_resolved_asset_alias"
 
 LATEST_VERSION_DAG_RESPONSE: dict = {
     "edges": [],
     "nodes": [
@@ -95,8 +98,10 @@ def examples_dag_bag() -> DagBag:
 @pytest.fixture(autouse=True)
 def clean():
     clear_db_runs()
+    clear_db_assets()
     yield
     clear_db_runs()
+    clear_db_assets()
 
 
 @pytest.fixture
@@ -115,7 +120,7 @@ def asset3() -> Dataset:
 
 
 @pytest.fixture
-def make_dag(dag_maker, session, time_machine, asset1: Asset, asset2: Asset, asset3: Dataset) -> None:
+def make_dags(dag_maker, session, time_machine, asset1: Asset, asset2: Asset, asset3: Dataset) -> None:
     with dag_maker(
         dag_id=DAG_ID_EXTERNAL_TRIGGER,
         serialized=True,
@@ -123,7 +128,6 @@ def make_dag(dag_maker, session, time_machine, asset1: Asset, asset2: Asset, ass
         start_date=pendulum.DateTime(2023, 2, 1, 0, 0, 0, tzinfo=pendulum.UTC),
     ):
         TriggerDagRunOperator(task_id="trigger_dag_run_operator", trigger_dag_id=DAG_ID)
-
     dag_maker.sync_dagbag_to_db()
 
     with dag_maker(
@@ -138,7 +142,45 @@ def make_dag(dag_maker, session, time_machine, asset1: Asset, asset2: Asset, ass
             >> ExternalTaskSensor(task_id="external_task_sensor", external_dag_id=DAG_ID)
             >> EmptyOperator(task_id="task_2")
         )
+    dag_maker.sync_dagbag_to_db()
+
+    with dag_maker(
+        dag_id=DAG_ID_RESOLVED_ASSET_ALIAS,
+        serialized=True,
+        session=session,
+        start_date=pendulum.DateTime(2023, 2, 1, 0, 0, 0, tzinfo=pendulum.UTC),
+    ):
+
+        @task(outlets=[AssetAlias("example-alias-resolved")])
+        def task_1(**context):
+            yield Metadata(
+                asset=Asset("resolved_example_asset_alias"),
+                extra={"k": "v"},  # extra has to be provided, can be {}
+                alias=AssetAlias("example-alias-resolved"),
+            )
+        task_1() >> EmptyOperator(task_id="task_2")
+
+    dr = dag_maker.create_dagrun()
+    asset_alias = session.scalar(
+        select(AssetAliasModel).where(AssetAliasModel.name == "example-alias-resolved")
+    )
+    asset_model = AssetModel(name="resolved_example_asset_alias")
+    session.add(asset_model)
+    session.flush()
+    asset_alias.assets.append(asset_model)
+    asset_alias.asset_events.append(
+        AssetEvent(
+            id=1,
+            timestamp=timezone.parse("2021-01-01T00:00:00"),
+            asset_id=asset_model.id,
+            source_dag_id=DAG_ID_RESOLVED_ASSET_ALIAS,
+            source_task_id="task_1",
+            source_run_id=dr.run_id,
+            source_map_index=-1,
+        )
+    )
+    session.commit()
 
     dag_maker.sync_dagbag_to_db()
@@ -151,17 +193,17 @@ def _fetch_asset_id(asset: Asset, session: Session) -> str:
 
 
 @pytest.fixture
-def asset1_id(make_dag, asset1, session: Session) -> str:
+def asset1_id(make_dags, asset1, session: Session) -> str:
     return _fetch_asset_id(asset1, session)
 
 
 @pytest.fixture
-def asset2_id(make_dag, asset2, session) -> str:
+def asset2_id(make_dags, asset2, session) -> str:
     return _fetch_asset_id(asset2, session)
 
 
 @pytest.fixture
-def asset3_id(make_dag, asset3, session) -> str:
+def asset3_id(make_dags, asset3, session) -> str:
     return _fetch_asset_id(asset3, session)
 
 
@@ -296,13 +338,13 @@ class TestStructureDataEndpoint:
             ),
         ],
     )
-    @pytest.mark.usefixtures("make_dag")
+    @pytest.mark.usefixtures("make_dags")
     def test_should_return_200(self, test_client, params, expected):
         response = test_client.get("/structure/structure_data", params=params)
         assert response.status_code == 200
         assert response.json() == expected
 
-    @pytest.mark.usefixtures("make_dag")
+    @pytest.mark.usefixtures("make_dags")
     def test_should_return_200_with_asset(self, test_client, asset1_id, asset2_id, asset3_id):
         params = {
             "dag_id": DAG_ID,
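For context, the `make_dags` fixture above simulates at the ORM level (manual `AssetModel`, alias link, and `AssetEvent` rows) what normally happens when the task actually runs. The authoring-side pattern it mimics looks roughly like this; the DAG id and names are invented, and `DAG` is assumed importable from `airflow.sdk` as in Airflow 3:

```python
from airflow.sdk import DAG, Metadata, task
from airflow.sdk.definitions.asset import Asset, AssetAlias

with DAG(dag_id="alias_producer_example"):

    @task(outlets=[AssetAlias("example-alias")])
    def produce(**context):
        # Attach a concrete asset to the alias at run time; the scheduler
        # records an AssetEvent pointing back at this task when it runs.
        yield Metadata(
            asset=Asset("asset_behind_alias"),
            extra={},  # extra is required; an empty dict is fine
            alias=AssetAlias("example-alias"),
        )

    produce()
```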
"label": None, + "is_source_asset": None, + }, + { + "source_id": "task_1", + "target_id": f"asset:{resolved_asset.id}", + "is_setup_teardown": None, + "label": None, + "is_source_asset": None, + }, + ], + "nodes": [ + { + "id": "task_1", + "label": "task_1", + "type": "task", + "children": None, + "is_mapped": None, + "tooltip": None, + "setup_teardown_type": None, + "operator": "@task", + "asset_condition_type": None, + }, + { + "id": "task_2", + "label": "task_2", + "type": "task", + "children": None, + "is_mapped": None, + "tooltip": None, + "setup_teardown_type": None, + "operator": "EmptyOperator", + "asset_condition_type": None, + }, + { + "id": f"asset:{resolved_asset.id}", + "label": "resolved_example_asset_alias", + "type": "asset", + "children": None, + "is_mapped": None, + "tooltip": None, + "setup_teardown_type": None, + "operator": None, + "asset_condition_type": None, + }, + ], + } + + response = test_client.get("/structure/structure_data", params=params) + assert response.status_code == 200 + assert response.json() == expected + @pytest.mark.parametrize( "params, expected", [