Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions airflow/assets/evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import functools
from typing import TYPE_CHECKING

import attrs

from airflow.models.asset import expand_alias_to_assets, resolve_ref_to_asset
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetBooleanCondition,
AssetRef,
AssetUniqueKey,
BaseAsset,
)
from airflow.sdk.definitions.asset.decorators import MultiAssetDefinition

if TYPE_CHECKING:
from sqlalchemy.orm import Session


@attrs.define
class AssetEvaluator:
"""Evaluates whether an asset-like object has been satisfied."""

_session: Session

def _resolve_asset_ref(self, o: AssetRef) -> Asset | None:
asset = resolve_ref_to_asset(**attrs.asdict(o), session=self._session)
return asset.to_public() if asset else None

def _resolve_asset_alias(self, o: AssetAlias) -> list[Asset]:
asset_models = expand_alias_to_assets(o.name, session=self._session)
return [m.to_public() for m in asset_models]

@functools.singledispatchmethod
def run(self, o: BaseAsset, statuses: dict[AssetUniqueKey, bool]) -> bool:
raise NotImplementedError(f"can not evaluate {o!r}")

@run.register
def _(self, o: Asset, statuses: dict[AssetUniqueKey, bool]) -> bool:
return statuses.get(AssetUniqueKey.from_asset(o), False)

@run.register
def _(self, o: AssetRef, statuses: dict[AssetUniqueKey, bool]) -> bool:
if asset := self._resolve_asset_ref(o):
return self.run(asset, statuses)
return False

@run.register
def _(self, o: AssetAlias, statuses: dict[AssetUniqueKey, bool]) -> bool:
return any(self.run(x, statuses) for x in self._resolve_asset_alias(o))

@run.register
def _(self, o: AssetBooleanCondition, statuses: dict[AssetUniqueKey, bool]) -> bool:
return o.agg_func(self.run(x, statuses) for x in o.objects)

@run.register
def _(self, o: MultiAssetDefinition, statuses: dict[AssetUniqueKey, bool]) -> bool:
return all(self.run(x, statuses) for x in o.iter_outlets())
6 changes: 3 additions & 3 deletions airflow/models/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ def fetch_active_assets_by_uri(uris: Iterable[str], session: Session) -> dict[st
}


def expand_alias_to_assets(alias_name: str, session: Session) -> Iterable[AssetModel]:
def expand_alias_to_assets(alias_name: str, *, session: Session) -> Iterable[AssetModel]:
"""Expand asset alias to resolved assets."""
asset_alias_obj = session.scalar(
select(AssetAliasModel).where(AssetAliasModel.name == alias_name).limit(1)
)
if asset_alias_obj:
return list(asset_alias_obj.assets)
return []
return iter(asset_alias_obj.assets)
return iter(())


def resolve_ref_to_asset(
Expand Down
5 changes: 4 additions & 1 deletion airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from sqlalchemy.sql import Select, expression

from airflow import settings, utils
from airflow.assets.evaluation import AssetEvaluator
from airflow.configuration import conf as airflow_conf, secrets_backend_list
from airflow.exceptions import (
AirflowException,
Expand Down Expand Up @@ -2323,12 +2324,14 @@ def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, dateti
"""
from airflow.models.serialized_dag import SerializedDagModel

evaluator = AssetEvaluator(session)

def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict[AssetUniqueKey, bool]) -> bool | None:
# if dag was serialized before 2.9 and we *just* upgraded,
# we may be dealing with old version. In that case,
# just wait for the dag to be reserialized.
try:
return cond.evaluate(statuses, session=session)
return evaluator.run(cond, statuses)
except AttributeError:
log.warning("dag '%s' has old serialization; skipping DAG run creation.", dag_id)
return None
Expand Down
102 changes: 22 additions & 80 deletions task-sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from __future__ import annotations

import contextlib
import logging
import operator
import os
Expand All @@ -34,8 +33,6 @@
from collections.abc import Iterable, Iterator
from urllib.parse import SplitResult

from sqlalchemy.orm import Session

from airflow.models.asset import AssetModel
from airflow.serialization.serialized_objects import SerializedAssetWatcher
from airflow.triggers.base import BaseEventTrigger
Expand Down Expand Up @@ -233,9 +230,6 @@ def as_expression(self) -> Any:
"""
raise NotImplementedError

def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: Session | None = None) -> bool:
raise NotImplementedError

def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
raise NotImplementedError

Expand Down Expand Up @@ -442,9 +436,6 @@ def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
def iter_asset_refs(self) -> Iterator[AssetRef]:
return iter(())

def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: Session | None = None) -> bool:
return statuses.get(AssetUniqueKey.from_asset(self), False)

def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
"""
Iterate an asset as dag dependency.
Expand Down Expand Up @@ -489,35 +480,14 @@ def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
def iter_asset_refs(self) -> Iterator[AssetRef]:
yield self

def _resolve_asset(self, *, session: Session | None = None) -> Asset | None:
from airflow.models.asset import resolve_ref_to_asset
from airflow.utils.session import create_session

with contextlib.nullcontext(session) if session else create_session() as session:
asset = resolve_ref_to_asset(**attrs.asdict(self), session=session)
return asset.to_public() if asset else None

def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: Session | None = None) -> bool:
if asset := self._resolve_asset(session=session):
return asset.evaluate(statuses=statuses, session=session)
return False

def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterator[DagDependency]:
(dependency_id,) = attrs.astuple(self)
if asset := self._resolve_asset():
yield DagDependency(
source=f"asset-ref:{dependency_id}" if source else "asset",
target="asset" if source else f"asset-ref:{dependency_id}",
dependency_type="asset",
dependency_id=asset.name,
)
else:
yield DagDependency(
source=source or "asset-ref",
target=target or "asset-ref",
dependency_type="asset-ref",
dependency_id=dependency_id,
)
yield DagDependency(
source=source or "asset-ref",
target=target or "asset-ref",
dependency_type="asset-ref",
dependency_id=dependency_id,
)


@attrs.define(hash=True)
Expand Down Expand Up @@ -553,14 +523,6 @@ class AssetAlias(BaseAsset):
name: str = attrs.field(validator=_validate_non_empty_identifier)
group: str = attrs.field(kw_only=True, default="asset", validator=_validate_identifier)

def _resolve_assets(self, session: Session | None = None) -> list[Asset]:
from airflow.models.asset import expand_alias_to_assets
from airflow.utils.session import create_session

with contextlib.nullcontext(session) if session else create_session() as session:
asset_models = expand_alias_to_assets(self.name, session)
return [m.to_public() for m in asset_models]

def as_expression(self) -> Any:
"""
Serialize the asset alias into its scheduling expression.
Expand All @@ -569,9 +531,6 @@ def as_expression(self) -> Any:
"""
return {"alias": {"name": self.name, "group": self.group}}

def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: Session | None = None) -> bool:
return any(x.evaluate(statuses=statuses, session=session) for x in self._resolve_assets(session))

def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
return iter(())

Expand All @@ -587,34 +546,20 @@ def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterat

:meta private:
"""
if not (resolved_assets := self._resolve_assets()):
yield DagDependency(
source=source or "asset-alias",
target=target or "asset-alias",
dependency_type="asset-alias",
dependency_id=self.name,
)
return
for asset in resolved_assets:
asset_name = asset.name
# asset
yield DagDependency(
source=f"asset-alias:{self.name}" if source else "asset",
target="asset" if source else f"asset-alias:{self.name}",
dependency_type="asset",
dependency_id=asset_name,
)
# asset alias
yield DagDependency(
source=source or f"asset:{asset_name}",
target=target or f"asset:{asset_name}",
dependency_type="asset-alias",
dependency_id=self.name,
)


class _AssetBooleanCondition(BaseAsset):
"""Base class for asset boolean logic."""
yield DagDependency(
source=source or "asset-alias",
target=target or "asset-alias",
dependency_type="asset-alias",
dependency_id=self.name,
)


class AssetBooleanCondition(BaseAsset):
"""
Base class for asset boolean logic.

:meta private:
"""

agg_func: Callable[[Iterable], bool]

Expand All @@ -623,9 +568,6 @@ def __init__(self, *objects: BaseAsset) -> None:
raise TypeError("expect asset expressions in condition")
self.objects = objects

def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: Session | None = None) -> bool:
return self.agg_func(x.evaluate(statuses=statuses, session=session) for x in self.objects)

def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
for o in self.objects:
yield from o.iter_assets()
Expand All @@ -648,7 +590,7 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe
yield from obj.iter_dag_dependencies(source=source, target=target)


class AssetAny(_AssetBooleanCondition):
class AssetAny(AssetBooleanCondition):
"""Use to combine assets schedule references in an "or" relationship."""

agg_func = any
Expand All @@ -671,7 +613,7 @@ def as_expression(self) -> dict[str, Any]:
return {"any": [o.as_expression() for o in self.objects]}


class AssetAll(_AssetBooleanCondition):
class AssetAll(AssetBooleanCondition):
"""Use to combine assets schedule references in an "and" relationship."""

agg_func = all
Expand Down
9 changes: 4 additions & 5 deletions task-sdk/src/airflow/sdk/definitions/asset/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
if TYPE_CHECKING:
from collections.abc import Callable, Collection, Iterator, Mapping

from sqlalchemy.orm import Session

from airflow.io.path import ObjectStoragePath
from airflow.sdk.definitions.asset import AssetAlias, AssetUniqueKey
from airflow.sdk.definitions.dag import DAG, DagStateChangeCallback, ScheduleArg
Expand Down Expand Up @@ -122,9 +120,6 @@ def __attrs_post_init__(self) -> None:
with self._source.create_dag(dag_id=self._function.__name__):
_AssetMainOperator.from_definition(self)

def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: Session | None = None) -> bool:
return all(o.evaluate(statuses=statuses, session=session) for o in self._source.outlets)

def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
for o in self._source.outlets:
yield from o.iter_assets()
Expand All @@ -141,6 +136,10 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe
for obj in self._source.outlets:
yield from obj.iter_dag_dependencies(source=source, target=target)

def iter_outlets(self) -> Iterator[BaseAsset]:
"""For asset evaluation in the scheduler."""
return iter(self._source.outlets)


@attrs.define(kw_only=True)
class _DAGFactory:
Expand Down
Loading