diff --git a/airflow/assets/evaluation.py b/airflow/assets/evaluation.py new file mode 100644 index 0000000000000..b1a877d4d542b --- /dev/null +++ b/airflow/assets/evaluation.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import functools +from typing import TYPE_CHECKING + +import attrs + +from airflow.models.asset import expand_alias_to_assets, resolve_ref_to_asset +from airflow.sdk.definitions.asset import ( + Asset, + AssetAlias, + AssetBooleanCondition, + AssetRef, + AssetUniqueKey, + BaseAsset, +) +from airflow.sdk.definitions.asset.decorators import MultiAssetDefinition + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + +@attrs.define +class AssetEvaluator: + """Evaluates whether an asset-like object has been satisfied.""" + + _session: Session + + def _resolve_asset_ref(self, o: AssetRef) -> Asset | None: + asset = resolve_ref_to_asset(**attrs.asdict(o), session=self._session) + return asset.to_public() if asset else None + + def _resolve_asset_alias(self, o: AssetAlias) -> list[Asset]: + asset_models = expand_alias_to_assets(o.name, session=self._session) + return [m.to_public() for m in asset_models] + + @functools.singledispatchmethod + def run(self, o: BaseAsset, statuses: dict[AssetUniqueKey, bool]) -> bool: + raise NotImplementedError(f"can not evaluate {o!r}") + + @run.register + def _(self, o: Asset, statuses: dict[AssetUniqueKey, bool]) -> bool: + return statuses.get(AssetUniqueKey.from_asset(o), False) + + @run.register + def _(self, o: AssetRef, statuses: dict[AssetUniqueKey, bool]) -> bool: + if asset := self._resolve_asset_ref(o): + return self.run(asset, statuses) + return False + + @run.register + def _(self, o: AssetAlias, statuses: dict[AssetUniqueKey, bool]) -> bool: + return any(self.run(x, statuses) for x in self._resolve_asset_alias(o)) + + @run.register + def _(self, o: AssetBooleanCondition, statuses: dict[AssetUniqueKey, bool]) -> bool: + return o.agg_func(self.run(x, statuses) for x in o.objects) + + @run.register + def _(self, o: MultiAssetDefinition, statuses: dict[AssetUniqueKey, bool]) -> bool: + return all(self.run(x, statuses) for x in o.iter_outlets()) diff --git a/airflow/models/asset.py b/airflow/models/asset.py index 212a0b3a84c6c..5ec0b2e977cba 100644 --- a/airflow/models/asset.py +++ b/airflow/models/asset.py @@ -70,14 +70,14 @@ def fetch_active_assets_by_uri(uris: Iterable[str], session: Session) -> dict[st } -def expand_alias_to_assets(alias_name: str, session: Session) -> Iterable[AssetModel]: +def expand_alias_to_assets(alias_name: str, *, session: Session) -> Iterable[AssetModel]: """Expand asset alias to resolved assets.""" asset_alias_obj = session.scalar( select(AssetAliasModel).where(AssetAliasModel.name == alias_name).limit(1) ) if 
asset_alias_obj: - return list(asset_alias_obj.assets) - return [] + return iter(asset_alias_obj.assets) + return iter(()) def resolve_ref_to_asset( diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 16cbf28d64706..cbb73e45f8402 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -68,6 +68,7 @@ from sqlalchemy.sql import Select, expression from airflow import settings, utils +from airflow.assets.evaluation import AssetEvaluator from airflow.configuration import conf as airflow_conf, secrets_backend_list from airflow.exceptions import ( AirflowException, @@ -2323,12 +2324,14 @@ def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, dateti """ from airflow.models.serialized_dag import SerializedDagModel + evaluator = AssetEvaluator(session) + def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict[AssetUniqueKey, bool]) -> bool | None: # if dag was serialized before 2.9 and we *just* upgraded, # we may be dealing with old version. In that case, # just wait for the dag to be reserialized. try: - return cond.evaluate(statuses, session=session) + return evaluator.run(cond, statuses) except AttributeError: log.warning("dag '%s' has old serialization; skipping DAG run creation.", dag_id) return None diff --git a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py index 11b7a1b2daefe..94559988c9d4a 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -17,7 +17,6 @@ from __future__ import annotations -import contextlib import logging import operator import os @@ -34,8 +33,6 @@ from collections.abc import Iterable, Iterator from urllib.parse import SplitResult - from sqlalchemy.orm import Session - from airflow.models.asset import AssetModel from airflow.serialization.serialized_objects import SerializedAssetWatcher from airflow.triggers.base import BaseEventTrigger @@ -233,9 +230,6 @@ def as_expression(self) -> Any: """ raise NotImplementedError - def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: Session | None = None) -> bool: - raise NotImplementedError - def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: raise NotImplementedError @@ -442,9 +436,6 @@ def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: def iter_asset_refs(self) -> Iterator[AssetRef]: return iter(()) - def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: Session | None = None) -> bool: - return statuses.get(AssetUniqueKey.from_asset(self), False) - def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: """ Iterate an asset as dag dependency. 
@@ -489,35 +480,14 @@ def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: def iter_asset_refs(self) -> Iterator[AssetRef]: yield self - def _resolve_asset(self, *, session: Session | None = None) -> Asset | None: - from airflow.models.asset import resolve_ref_to_asset - from airflow.utils.session import create_session - - with contextlib.nullcontext(session) if session else create_session() as session: - asset = resolve_ref_to_asset(**attrs.asdict(self), session=session) - return asset.to_public() if asset else None - - def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: Session | None = None) -> bool: - if asset := self._resolve_asset(session=session): - return asset.evaluate(statuses=statuses, session=session) - return False - def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterator[DagDependency]: (dependency_id,) = attrs.astuple(self) - if asset := self._resolve_asset(): - yield DagDependency( - source=f"asset-ref:{dependency_id}" if source else "asset", - target="asset" if source else f"asset-ref:{dependency_id}", - dependency_type="asset", - dependency_id=asset.name, - ) - else: - yield DagDependency( - source=source or "asset-ref", - target=target or "asset-ref", - dependency_type="asset-ref", - dependency_id=dependency_id, - ) + yield DagDependency( + source=source or "asset-ref", + target=target or "asset-ref", + dependency_type="asset-ref", + dependency_id=dependency_id, + ) @attrs.define(hash=True) @@ -553,14 +523,6 @@ class AssetAlias(BaseAsset): name: str = attrs.field(validator=_validate_non_empty_identifier) group: str = attrs.field(kw_only=True, default="asset", validator=_validate_identifier) - def _resolve_assets(self, session: Session | None = None) -> list[Asset]: - from airflow.models.asset import expand_alias_to_assets - from airflow.utils.session import create_session - - with contextlib.nullcontext(session) if session else create_session() as session: - asset_models = expand_alias_to_assets(self.name, session) - return [m.to_public() for m in asset_models] - def as_expression(self) -> Any: """ Serialize the asset alias into its scheduling expression. 
@@ -569,9 +531,6 @@ def as_expression(self) -> Any: """ return {"alias": {"name": self.name, "group": self.group}} - def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: Session | None = None) -> bool: - return any(x.evaluate(statuses=statuses, session=session) for x in self._resolve_assets(session)) - def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: return iter(()) @@ -587,34 +546,20 @@ def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterat :meta private: """ - if not (resolved_assets := self._resolve_assets()): - yield DagDependency( - source=source or "asset-alias", - target=target or "asset-alias", - dependency_type="asset-alias", - dependency_id=self.name, - ) - return - for asset in resolved_assets: - asset_name = asset.name - # asset - yield DagDependency( - source=f"asset-alias:{self.name}" if source else "asset", - target="asset" if source else f"asset-alias:{self.name}", - dependency_type="asset", - dependency_id=asset_name, - ) - # asset alias - yield DagDependency( - source=source or f"asset:{asset_name}", - target=target or f"asset:{asset_name}", - dependency_type="asset-alias", - dependency_id=self.name, - ) - - -class _AssetBooleanCondition(BaseAsset): - """Base class for asset boolean logic.""" + yield DagDependency( + source=source or "asset-alias", + target=target or "asset-alias", + dependency_type="asset-alias", + dependency_id=self.name, + ) + + +class AssetBooleanCondition(BaseAsset): + """ + Base class for asset boolean logic. + + :meta private: + """ agg_func: Callable[[Iterable], bool] @@ -623,9 +568,6 @@ def __init__(self, *objects: BaseAsset) -> None: raise TypeError("expect asset expressions in condition") self.objects = objects - def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: Session | None = None) -> bool: - return self.agg_func(x.evaluate(statuses=statuses, session=session) for x in self.objects) - def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: for o in self.objects: yield from o.iter_assets() @@ -648,7 +590,7 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe yield from obj.iter_dag_dependencies(source=source, target=target) -class AssetAny(_AssetBooleanCondition): +class AssetAny(AssetBooleanCondition): """Use to combine assets schedule references in an "or" relationship.""" agg_func = any @@ -671,7 +613,7 @@ def as_expression(self) -> dict[str, Any]: return {"any": [o.as_expression() for o in self.objects]} -class AssetAll(_AssetBooleanCondition): +class AssetAll(AssetBooleanCondition): """Use to combine assets schedule references in an "and" relationship.""" agg_func = all diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py index c70a224858b1b..77ab57074bbfe 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -28,8 +28,6 @@ if TYPE_CHECKING: from collections.abc import Callable, Collection, Iterator, Mapping - from sqlalchemy.orm import Session - from airflow.io.path import ObjectStoragePath from airflow.sdk.definitions.asset import AssetAlias, AssetUniqueKey from airflow.sdk.definitions.dag import DAG, DagStateChangeCallback, ScheduleArg @@ -122,9 +120,6 @@ def __attrs_post_init__(self) -> None: with self._source.create_dag(dag_id=self._function.__name__): _AssetMainOperator.from_definition(self) - def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: 
Session | None = None) -> bool: - return all(o.evaluate(statuses=statuses, session=session) for o in self._source.outlets) - def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: for o in self._source.outlets: yield from o.iter_assets() @@ -141,6 +136,10 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe for obj in self._source.outlets: yield from obj.iter_dag_dependencies(source=source, target=target) + def iter_outlets(self) -> Iterator[BaseAsset]: + """For asset evaluation in the scheduler.""" + return iter(self._source.outlets) + @attrs.define(kw_only=True) class _DAGFactory: diff --git a/task-sdk/tests/task_sdk/definitions/test_asset.py b/task-sdk/tests/task_sdk/definitions/test_asset.py index 767cd9e1be714..1637cffac61be 100644 --- a/task-sdk/tests/task_sdk/definitions/test_asset.py +++ b/task-sdk/tests/task_sdk/definitions/test_asset.py @@ -37,7 +37,7 @@ _sanitize_uri, ) from airflow.sdk.definitions.dag import DAG -from airflow.serialization.serialized_objects import BaseSerialization, SerializedDAG +from airflow.serialization.serialized_objects import SerializedDAG ASSET_MODULE_PATH = "airflow.sdk.definitions.asset" @@ -185,18 +185,6 @@ def test_asset_iter_asset_aliases(): ] -@pytest.mark.parametrize( - "statuses, result", - [ - ({AssetUniqueKey.from_asset(asset1): True}, True), - ({AssetUniqueKey.from_asset(asset1): False}, False), - ({}, False), - ], -) -def test_asset_evaluate(statuses, result): - assert asset1.evaluate(statuses) is result - - def test_asset_any_operations(): result_or = (asset1 | asset2) | asset3 assert isinstance(result_or, AssetAny) @@ -212,116 +200,6 @@ def test_asset_all_operations(): assert isinstance(result_and, AssetAll) -@pytest.mark.parametrize( - "condition, statuses, result", - [ - ( - AssetAny(asset1, asset2), - {AssetUniqueKey.from_asset(asset1): False, AssetUniqueKey.from_asset(asset2): True}, - True, - ), - ( - AssetAll(asset1, asset2), - {AssetUniqueKey.from_asset(asset1): True, AssetUniqueKey.from_asset(asset2): False}, - False, - ), - ], -) -def test_assset_boolean_condition_evaluate_iter(condition, statuses, result): - """ - Tests _AssetBooleanCondition's evaluate and iter_assets methods through AssetAny and AssetAll. - Ensures AssetAny evaluate returns True with any true condition, AssetAll evaluate returns False if - any condition is false, and both classes correctly iterate over assets without duplication. 
- """ - assert condition.evaluate(statuses) is result - assert dict(condition.iter_assets()) == { - AssetUniqueKey("asset-1", "s3://bucket1/data1"): asset1, - AssetUniqueKey("asset-2", "s3://bucket2/data2"): asset2, - } - - -@pytest.mark.parametrize( - "inputs, scenario, expected", - [ - # Scenarios for AssetAny - ((True, True, True), "any", True), - ((True, True, False), "any", True), - ((True, False, True), "any", True), - ((True, False, False), "any", True), - ((False, False, True), "any", True), - ((False, True, False), "any", True), - ((False, True, True), "any", True), - ((False, False, False), "any", False), - # Scenarios for AssetAll - ((True, True, True), "all", True), - ((True, True, False), "all", False), - ((True, False, True), "all", False), - ((True, False, False), "all", False), - ((False, False, True), "all", False), - ((False, True, False), "all", False), - ((False, True, True), "all", False), - ((False, False, False), "all", False), - ], -) -def test_asset_logical_conditions_evaluation_and_serialization(inputs, scenario, expected): - class_ = AssetAny if scenario == "any" else AssetAll - assets = [Asset(uri=f"s3://abc/{i}", name=f"asset_{i}") for i in range(123, 126)] - condition = class_(*assets) - - statuses = {AssetUniqueKey.from_asset(asset): status for asset, status in zip(assets, inputs)} - assert ( - condition.evaluate(statuses) == expected - ), f"Condition evaluation failed for inputs {inputs} and scenario '{scenario}'" - - # Serialize and deserialize the condition to test persistence - serialized = BaseSerialization.serialize(condition) - deserialized = BaseSerialization.deserialize(serialized) - assert deserialized.evaluate(statuses) == expected, "Serialization round-trip failed" - - -@pytest.mark.parametrize( - "status_values, expected_evaluation", - [ - ( - (False, True, True), - False, - ), # AssetAll requires all conditions to be True, but asset1 is False - ((True, True, True), True), # All conditions are True - ( - (True, False, True), - True, - ), # asset1 is True, and AssetAny condition (asset2 or asset3 being True) is met - ( - (True, False, False), - False, - ), # asset1 is True, but neither asset2 nor asset3 meet the AssetAny condition - ], -) -def test_nested_asset_conditions_with_serialization(status_values, expected_evaluation): - # Define assets - asset1 = Asset(uri="s3://abc/123") - asset2 = Asset(uri="s3://abc/124") - asset3 = Asset(uri="s3://abc/125") - - # Create a nested condition: AssetAll with asset1 and AssetAny with asset2 and asset3 - nested_condition = AssetAll(asset1, AssetAny(asset2, asset3)) - - statuses = { - AssetUniqueKey.from_asset(asset1): status_values[0], - AssetUniqueKey.from_asset(asset2): status_values[1], - AssetUniqueKey.from_asset(asset3): status_values[2], - } - - assert nested_condition.evaluate(statuses) == expected_evaluation, "Initial evaluation mismatch" - - serialized_condition = BaseSerialization.serialize(nested_condition) - deserialized_condition = BaseSerialization.deserialize(serialized_condition) - - assert ( - deserialized_condition.evaluate(statuses) == expected_evaluation - ), "Post-serialization evaluation mismatch" - - @pytest.fixture def create_test_assets(): """Fixture to create test assets and corresponding models.""" @@ -500,38 +378,10 @@ def test_normalize_uri_valid_uri(mock_get_normalized_scheme): class TestAssetAlias: - @pytest.fixture - def asset(self): - """Example asset links to asset alias resolved_asset_alias_2.""" - return Asset(uri="test://asset1/", name="test_name", group="asset") - - 
@pytest.fixture - def asset_alias_1(self): - """Example asset alias links to no assets.""" - asset_alias_1 = AssetAlias(name="test_name", group="test") - with mock.patch.object(asset_alias_1, "_resolve_assets", return_value=[]): - yield asset_alias_1 - - @pytest.fixture - def resolved_asset_alias_2(self, asset): - """Example asset alias links to asset.""" - asset_alias_2 = AssetAlias(name="test_name_2") - with mock.patch.object(asset_alias_2, "_resolve_assets", return_value=[asset]): - yield asset_alias_2 - - @pytest.mark.parametrize("alias_fixture_name", ["asset_alias_1", "resolved_asset_alias_2"]) - def test_as_expression(self, request: pytest.FixtureRequest, alias_fixture_name): - alias = request.getfixturevalue(alias_fixture_name) + def test_as_expression(self): + alias = AssetAlias(name="test_name", group="test") assert alias.as_expression() == {"alias": {"name": alias.name, "group": alias.group}} - def test_evalute_empty(self, asset_alias_1, asset): - assert asset_alias_1.evaluate({AssetUniqueKey.from_asset(asset): True}) is False - assert asset_alias_1._resolve_assets.mock_calls == [mock.call(None)] - - def test_evalute_resolved(self, resolved_asset_alias_2, asset): - assert resolved_asset_alias_2.evaluate({AssetUniqueKey.from_asset(asset): True}) is True - assert resolved_asset_alias_2._resolve_assets.mock_calls == [mock.call(None)] - class TestAssetSubclasses: @pytest.mark.parametrize("subcls, group", ((Model, "model"), (Dataset, "dataset"))) diff --git a/tests/assets/test_evaluation.py b/tests/assets/test_evaluation.py new file mode 100644 index 0000000000000..1c8e909eee180 --- /dev/null +++ b/tests/assets/test_evaluation.py @@ -0,0 +1,199 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import pytest + +from airflow.assets.evaluation import AssetEvaluator +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny, AssetUniqueKey +from airflow.serialization.serialized_objects import BaseSerialization + +pytestmark = pytest.mark.db_test + +asset1 = Asset(uri="s3://bucket1/data1", name="asset-1") +asset2 = Asset(uri="s3://bucket2/data2", name="asset-2") + + +@pytest.fixture +def evaluator(session): + return AssetEvaluator(session) + + +@pytest.mark.parametrize( + "statuses, result", + [ + ({AssetUniqueKey.from_asset(asset1): True}, True), + ({AssetUniqueKey.from_asset(asset1): False}, False), + ({}, False), + ], +) +def test_asset_evaluate(evaluator, statuses, result): + assert evaluator.run(asset1, statuses) is result + + +@pytest.mark.parametrize( + "condition, statuses, result", + [ + ( + AssetAny(asset1, asset2), + {AssetUniqueKey.from_asset(asset1): False, AssetUniqueKey.from_asset(asset2): True}, + True, + ), + ( + AssetAll(asset1, asset2), + {AssetUniqueKey.from_asset(asset1): True, AssetUniqueKey.from_asset(asset2): False}, + False, + ), + ], +) +def test_asset_boolean_condition_evaluate_iter(evaluator, condition, statuses, result): + """ + Tests AssetEvaluator.run and iter_assets for AssetBooleanCondition subclasses via AssetAny and AssetAll. + + Ensures AssetAny evaluates to True when any condition is true, AssetAll evaluates to False when + any condition is false, and both classes correctly iterate over assets without duplication. + """ + assert evaluator.run(condition, statuses) is result + assert dict(condition.iter_assets()) == { + AssetUniqueKey("asset-1", "s3://bucket1/data1"): asset1, + AssetUniqueKey("asset-2", "s3://bucket2/data2"): asset2, + } + + +@pytest.mark.parametrize( + "inputs, scenario, expected", + [ + # Scenarios for AssetAny + ((True, True, True), "any", True), + ((True, True, False), "any", True), + ((True, False, True), "any", True), + ((True, False, False), "any", True), + ((False, False, True), "any", True), + ((False, True, False), "any", True), + ((False, True, True), "any", True), + ((False, False, False), "any", False), + # Scenarios for AssetAll + ((True, True, True), "all", True), + ((True, True, False), "all", False), + ((True, False, True), "all", False), + ((True, False, False), "all", False), + ((False, False, True), "all", False), + ((False, True, False), "all", False), + ((False, True, True), "all", False), + ((False, False, False), "all", False), + ], +) +def test_asset_logical_conditions_evaluation_and_serialization(evaluator, inputs, scenario, expected): + class_ = AssetAny if scenario == "any" else AssetAll + assets = [Asset(uri=f"s3://abc/{i}", name=f"asset_{i}") for i in range(123, 126)] + condition = class_(*assets) + + statuses = {AssetUniqueKey.from_asset(asset): status for asset, status in zip(assets, inputs)} + assert ( + evaluator.run(condition, statuses) == expected + ), f"Condition evaluation failed for inputs {inputs} and scenario '{scenario}'" + + # Serialize and deserialize the condition to test persistence + serialized = BaseSerialization.serialize(condition) + deserialized = BaseSerialization.deserialize(serialized) + assert evaluator.run(deserialized, statuses) == expected, "Serialization round-trip failed" + + +@pytest.mark.parametrize( + "status_values, expected_evaluation", + [ + pytest.param( + (False, True, True), + False, + id="f & (t | t)", + ), # AssetAll requires all conditions to be True, but asset1 is False + pytest.param( + (True, True, 
True), + True, + id="t & (t | t)", + ), # All conditions are True + pytest.param( + (True, False, True), + True, + id="t & (f | t)", + ), # asset1 is True, and AssetAny condition (asset2 or asset3 being True) is met + pytest.param( + (True, False, False), + False, + id="t & (f | f)", + ), # asset1 is True, but neither asset2 nor asset3 meet the AssetAny condition + ], +) +def test_nested_asset_conditions_with_serialization(evaluator, status_values, expected_evaluation): + # Define assets + asset1 = Asset(uri="s3://abc/123") + asset2 = Asset(uri="s3://abc/124") + asset3 = Asset(uri="s3://abc/125") + + # Create a nested condition: AssetAll with asset1 and AssetAny with asset2 and asset3 + nested_condition = AssetAll(asset1, AssetAny(asset2, asset3)) + + statuses = { + AssetUniqueKey.from_asset(asset1): status_values[0], + AssetUniqueKey.from_asset(asset2): status_values[1], + AssetUniqueKey.from_asset(asset3): status_values[2], + } + + assert evaluator.run(nested_condition, statuses) == expected_evaluation, "Initial evaluation mismatch" + + serialized_condition = BaseSerialization.serialize(nested_condition) + deserialized_condition = BaseSerialization.deserialize(serialized_condition) + + assert ( + evaluator.run(deserialized_condition, statuses) == expected_evaluation + ), "Post-serialization evaluation mismatch" + + +class TestAssetAlias: + @pytest.fixture + def asset(self): + """Example asset links to asset alias resolved_asset_alias_2.""" + return Asset(uri="test://asset1/", name="test_name", group="asset") + + @pytest.fixture + def asset_alias_1(self): + """Example asset alias links to no assets.""" + return AssetAlias(name="test_name", group="test") + + @pytest.fixture + def resolved_asset_alias_2(self): + """Example asset alias links to asset.""" + return AssetAlias(name="test_name_2") + + @pytest.fixture + def evaluator(self, session, asset_alias_1, resolved_asset_alias_2, asset): + class _AssetEvaluator(AssetEvaluator): # Can't use mock because AssetEvaluator sets __slots__. 
+ def _resolve_asset_alias(self, o): + if o is asset_alias_1: + return [] + elif o is resolved_asset_alias_2: + return [asset] + return super()._resolve_asset_alias(o) + + return _AssetEvaluator(session) + + def test_evaluate_empty(self, evaluator, asset_alias_1, asset): + assert evaluator.run(asset_alias_1, {AssetUniqueKey.from_asset(asset): True}) is False + + def test_evaluate_resolved(self, evaluator, resolved_asset_alias_2, asset): + assert evaluator.run(resolved_asset_alias_2, {AssetUniqueKey.from_asset(asset): True}) is True diff --git a/tests/models/test_asset.py b/tests/models/test_asset.py index 1e21252c5028f..5608d4c8d59e0 100644 --- a/tests/models/test_asset.py +++ b/tests/models/test_asset.py @@ -72,7 +72,7 @@ def resolved_asset_alias_2(self, session, asset_model): return asset_alias_2 def test_expand_alias_to_assets_empty(self, session, asset_alias_1): - assert expand_alias_to_assets(asset_alias_1.name, session) == [] + assert list(expand_alias_to_assets(asset_alias_1.name, session=session)) == [] def test_expand_alias_to_assets_resolved(self, session, resolved_asset_alias_2, asset_model): - assert expand_alias_to_assets(resolved_asset_alias_2.name, session) == [asset_model] + assert list(expand_alias_to_assets(resolved_asset_alias_2.name, session=session)) == [asset_model] diff --git a/tests/timetables/test_assets_timetable.py b/tests/timetables/test_assets_timetable.py index 9892b5805bd8d..d8386bc27f5d9 100644 --- a/tests/timetables/test_assets_timetable.py +++ b/tests/timetables/test_assets_timetable.py @@ -273,8 +273,11 @@ def create_test_assets(self): return [Asset(uri=f"test://asset{i}", name=f"hello{i}") for i in range(1, 3)] def test_asset_dag_run_queue_processing(self, session, dag_maker, create_test_assets): + from airflow.assets.evaluation import AssetEvaluator + assets = create_test_assets asset_models = session.query(AssetModel).all() + evaluator = AssetEvaluator(session) with dag_maker(schedule=AssetAny(*assets)) as dag: EmptyOperator(task_id="hello") @@ -298,7 +301,7 @@ def test_asset_dag_run_queue_processing(self, session, dag_maker, create_test_as dag = SerializedDAG.deserialize(serialized_dag.data) for asset_uri, status in dag_statuses[dag.dag_id].items(): cond = dag.timetable.asset_condition - assert cond.evaluate({asset_uri: status}), "DAG trigger evaluation failed" + assert evaluator.run(cond, {asset_uri: status}), "DAG trigger evaluation failed" def test_dag_with_complex_asset_condition(self, session, dag_maker): # Create Asset instances
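A minimal usage sketch of the new AssetEvaluator, for review context only (not part of the change set). It assumes an existing SQLAlchemy ORM session named session; the assets and statuses below are made up for illustration and mirror the call pattern used in dags_needing_dagruns above.

from airflow.assets.evaluation import AssetEvaluator
from airflow.sdk.definitions.asset import Asset, AssetAny, AssetUniqueKey

# Hypothetical assets and completion statuses; in the scheduler these come from
# the serialized DAG's asset condition and asset-event bookkeeping.
a1 = Asset(uri="s3://bucket/a1", name="a1")
a2 = Asset(uri="s3://bucket/a2", name="a2")
statuses = {AssetUniqueKey.from_asset(a1): True, AssetUniqueKey.from_asset(a2): False}

evaluator = AssetEvaluator(session)  # assumed pre-existing SQLAlchemy ORM session
# run() dispatches on the condition type (Asset, AssetRef, AssetAlias,
# AssetBooleanCondition, MultiAssetDefinition) via singledispatchmethod.
assert evaluator.run(AssetAny(a1, a2), statuses) is True

Keeping the session on the evaluator, rather than threading it through BaseAsset.evaluate, is what lets the task-sdk asset definitions drop their SQLAlchemy imports and session-aware evaluate methods in this diff.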