Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't commit for read-only query #44905

Merged
merged 2 commits into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2281,7 +2281,7 @@ def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict) -> bool | None:
# we may be dealing with old version. In that case,
# just wait for the dag to be reserialized.
try:
return cond.evaluate(statuses)
return cond.evaluate(statuses, session=session)
except AttributeError:
log.warning("dag '%s' has old serialization; skipping DAG run creation.", dag_id)
return None
Expand Down
3 changes: 2 additions & 1 deletion airflow/timetables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

if TYPE_CHECKING:
from pendulum import DateTime
from sqlalchemy.orm import Session

from airflow.sdk.definitions.asset import Asset, AssetAlias
from airflow.serialization.dag_dependency import DagDependency
Expand Down Expand Up @@ -52,7 +53,7 @@ def __and__(self, other: BaseAsset) -> BaseAsset:
def as_expression(self) -> Any:
return None

def evaluate(self, statuses: dict[str, bool]) -> bool:
def evaluate(self, statuses: dict[str, bool], *, session: Session | None = None) -> bool:
return False

def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
Expand Down
19 changes: 11 additions & 8 deletions task_sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import contextlib
import logging
import operator
import os
Expand All @@ -32,6 +33,8 @@
from collections.abc import Iterable, Iterator
from urllib.parse import SplitResult

from sqlalchemy.orm import Session

from airflow.models.asset import AssetModel
from airflow.triggers.base import BaseTrigger

Expand Down Expand Up @@ -227,7 +230,7 @@ def as_expression(self) -> Any:
"""
raise NotImplementedError

def evaluate(self, statuses: dict[str, bool]) -> bool:
def evaluate(self, statuses: dict[str, bool], *, session: Session | None = None) -> bool:
raise NotImplementedError

def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
Expand Down Expand Up @@ -385,7 +388,7 @@ def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
return iter(())

def evaluate(self, statuses: dict[str, bool]) -> bool:
def evaluate(self, statuses: dict[str, bool], *, session: Session | None = None) -> bool:
return statuses.get(self.uri, False)

def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
Expand Down Expand Up @@ -428,11 +431,11 @@ class AssetAlias(BaseAsset):
name: str = attrs.field(validator=_validate_non_empty_identifier)
group: str = attrs.field(kw_only=True, default="asset", validator=_validate_identifier)

def _resolve_assets(self) -> list[Asset]:
def _resolve_assets(self, session: Session | None = None) -> list[Asset]:
from airflow.models.asset import expand_alias_to_assets
from airflow.utils.session import create_session

with create_session() as session:
with contextlib.nullcontext(session) if session else create_session() as session:
asset_models = expand_alias_to_assets(self.name, session)
return [m.to_public() for m in asset_models]

Expand All @@ -444,8 +447,8 @@ def as_expression(self) -> Any:
"""
return {"alias": {"name": self.name, "group": self.group}}

def evaluate(self, statuses: dict[str, bool]) -> bool:
return any(x.evaluate(statuses=statuses) for x in self._resolve_assets())
def evaluate(self, statuses: dict[str, bool], *, session: Session | None = None) -> bool:
return any(x.evaluate(statuses=statuses, session=session) for x in self._resolve_assets(session))

def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
return iter(())
Expand Down Expand Up @@ -495,8 +498,8 @@ def __init__(self, *objects: BaseAsset) -> None:
raise TypeError("expect asset expressions in condition")
self.objects = objects

def evaluate(self, statuses: dict[str, bool]) -> bool:
return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects)
def evaluate(self, statuses: dict[str, bool], *, session: Session | None = None) -> bool:
return self.agg_func(x.evaluate(statuses=statuses, session=session) for x in self.objects)

def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
seen: set[AssetUniqueKey] = set() # We want to keep the first instance.
Expand Down
6 changes: 4 additions & 2 deletions task_sdk/src/airflow/sdk/definitions/asset/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
if TYPE_CHECKING:
from collections.abc import Callable, Collection, Iterator, Mapping

from sqlalchemy.orm import Session

from airflow.io.path import ObjectStoragePath
from airflow.models.param import ParamsDict
from airflow.sdk.definitions.asset import AssetAlias, AssetUniqueKey
Expand Down Expand Up @@ -120,8 +122,8 @@ def __attrs_post_init__(self) -> None:
with self._source.create_dag(dag_id=self._function.__name__):
_AssetMainOperator.from_definition(self)

def evaluate(self, statuses: dict[str, bool]) -> bool:
return all(o.evaluate(statuses=statuses) for o in self._source.outlets)
def evaluate(self, statuses: dict[str, bool], *, session: Session | None = None) -> bool:
return all(o.evaluate(statuses=statuses, session=session) for o in self._source.outlets)

def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
for o in self._source.outlets:
Expand Down
4 changes: 2 additions & 2 deletions task_sdk/tests/defintions/test_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,11 +514,11 @@ def test_as_expression(self, request: pytest.FixtureRequest, alias_fixture_name)

def test_evalute_empty(self, asset_alias_1, asset):
assert asset_alias_1.evaluate({asset.uri: True}) is False
assert asset_alias_1._resolve_assets.mock_calls == [mock.call()]
assert asset_alias_1._resolve_assets.mock_calls == [mock.call(None)]

def test_evalute_resolved(self, resolved_asset_alias_2, asset):
assert resolved_asset_alias_2.evaluate({asset.uri: True}) is True
assert resolved_asset_alias_2._resolve_assets.mock_calls == [mock.call()]
assert resolved_asset_alias_2._resolve_assets.mock_calls == [mock.call(None)]


class TestAssetSubclasses:
Expand Down