From df68804b192a27488ad828a9e9da9d2c7df3e1d5 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Fri, 13 Dec 2024 15:54:44 +0800 Subject: [PATCH 1/2] Don't commit for read-only query --- task_sdk/src/airflow/sdk/definitions/asset/__init__.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index a0beb24150f7d..289a0be085c00 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -429,10 +429,13 @@ class AssetAlias(BaseAsset): group: str = attrs.field(kw_only=True, default="asset", validator=_validate_identifier) def _resolve_assets(self) -> list[Asset]: + from airflow import settings from airflow.models.asset import expand_alias_to_assets - from airflow.utils.session import create_session - with create_session() as session: + # We don't use create_session or provide_session here because this is + # called in the scheduler when commit is prohibited (for HA reasons). + # Those functions assume you want to commit the result on exit. + with settings.Session() as session: asset_models = expand_alias_to_assets(self.name, session) return [m.to_public() for m in asset_models] From 6ccc74a24b6270586dd8ea21c5eba361a9421684 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Fri, 13 Dec 2024 17:57:40 +0800 Subject: [PATCH 2/2] Pass session from outside when we can This does not use the create_session/provide_session/NEW_SESSION paradigm because it requires importing them globally, which is not allowed in the SDK. We also do not want to create a session unless absolutely needed. --- airflow/models/dag.py | 2 +- airflow/timetables/base.py | 3 ++- .../airflow/sdk/definitions/asset/__init__.py | 24 +++++++++---------- .../sdk/definitions/asset/decorators.py | 6 +++-- task_sdk/tests/defintions/test_asset.py | 4 ++-- 5 files changed, 21 insertions(+), 18 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 4d988367073da..6163580870c4e 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2281,7 +2281,7 @@ def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict) -> bool | None: # we may be dealing with old version. In that case, # just wait for the dag to be reserialized. try: - return cond.evaluate(statuses) + return cond.evaluate(statuses, session=session) except AttributeError: log.warning("dag '%s' has old serialization; skipping DAG run creation.", dag_id) return None diff --git a/airflow/timetables/base.py b/airflow/timetables/base.py index b80a6323a8c18..c719f92437be1 100644 --- a/airflow/timetables/base.py +++ b/airflow/timetables/base.py @@ -24,6 +24,7 @@ if TYPE_CHECKING: from pendulum import DateTime + from sqlalchemy.orm import Session from airflow.sdk.definitions.asset import Asset, AssetAlias from airflow.serialization.dag_dependency import DagDependency @@ -52,7 +53,7 @@ def __and__(self, other: BaseAsset) -> BaseAsset: def as_expression(self) -> Any: return None - def evaluate(self, statuses: dict[str, bool]) -> bool: + def evaluate(self, statuses: dict[str, bool], *, session: Session | None = None) -> bool: return False def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index 289a0be085c00..c7d3906cf2e56 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -17,6 +17,7 @@ from __future__ import annotations +import contextlib import logging import operator import os @@ -32,6 +33,8 @@ from collections.abc import Iterable, Iterator from urllib.parse import SplitResult + from sqlalchemy.orm import Session + from airflow.models.asset import AssetModel from airflow.triggers.base import BaseTrigger @@ -227,7 +230,7 @@ def as_expression(self) -> Any: """ raise NotImplementedError - def evaluate(self, statuses: dict[str, bool]) -> bool: + def evaluate(self, statuses: dict[str, bool], *, session: Session | None = None) -> bool: raise NotImplementedError def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: @@ -385,7 +388,7 @@ def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: return iter(()) - def evaluate(self, statuses: dict[str, bool]) -> bool: + def evaluate(self, statuses: dict[str, bool], *, session: Session | None = None) -> bool: return statuses.get(self.uri, False) def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: @@ -428,14 +431,11 @@ class AssetAlias(BaseAsset): name: str = attrs.field(validator=_validate_non_empty_identifier) group: str = attrs.field(kw_only=True, default="asset", validator=_validate_identifier) - def _resolve_assets(self) -> list[Asset]: - from airflow import settings + def _resolve_assets(self, session: Session | None = None) -> list[Asset]: from airflow.models.asset import expand_alias_to_assets + from airflow.utils.session import create_session - # We don't use create_session or provide_session here because this is - # called in the scheduler when commit is prohibited (for HA reasons). - # Those functions assume you want to commit the result on exit. - with settings.Session() as session: + with contextlib.nullcontext(session) if session else create_session() as session: asset_models = expand_alias_to_assets(self.name, session) return [m.to_public() for m in asset_models] @@ -447,8 +447,8 @@ def as_expression(self) -> Any: """ return {"alias": {"name": self.name, "group": self.group}} - def evaluate(self, statuses: dict[str, bool]) -> bool: - return any(x.evaluate(statuses=statuses) for x in self._resolve_assets()) + def evaluate(self, statuses: dict[str, bool], *, session: Session | None = None) -> bool: + return any(x.evaluate(statuses=statuses, session=session) for x in self._resolve_assets(session)) def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: return iter(()) @@ -498,8 +498,8 @@ def __init__(self, *objects: BaseAsset) -> None: raise TypeError("expect asset expressions in condition") self.objects = objects - def evaluate(self, statuses: dict[str, bool]) -> bool: - return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects) + def evaluate(self, statuses: dict[str, bool], *, session: Session | None = None) -> bool: + return self.agg_func(x.evaluate(statuses=statuses, session=session) for x in self.objects) def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: seen: set[AssetUniqueKey] = set() # We want to keep the first instance. diff --git a/task_sdk/src/airflow/sdk/definitions/asset/decorators.py b/task_sdk/src/airflow/sdk/definitions/asset/decorators.py index 1cb1ea4e31696..531e097fd99f8 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -28,6 +28,8 @@ if TYPE_CHECKING: from collections.abc import Callable, Collection, Iterator, Mapping + from sqlalchemy.orm import Session + from airflow.io.path import ObjectStoragePath from airflow.models.param import ParamsDict from airflow.sdk.definitions.asset import AssetAlias, AssetUniqueKey @@ -120,8 +122,8 @@ def __attrs_post_init__(self) -> None: with self._source.create_dag(dag_id=self._function.__name__): _AssetMainOperator.from_definition(self) - def evaluate(self, statuses: dict[str, bool]) -> bool: - return all(o.evaluate(statuses=statuses) for o in self._source.outlets) + def evaluate(self, statuses: dict[str, bool], *, session: Session | None = None) -> bool: + return all(o.evaluate(statuses=statuses, session=session) for o in self._source.outlets) def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: for o in self._source.outlets: diff --git a/task_sdk/tests/defintions/test_asset.py b/task_sdk/tests/defintions/test_asset.py index afdfb37a5fd5d..82225a619664d 100644 --- a/task_sdk/tests/defintions/test_asset.py +++ b/task_sdk/tests/defintions/test_asset.py @@ -514,11 +514,11 @@ def test_as_expression(self, request: pytest.FixtureRequest, alias_fixture_name) def test_evalute_empty(self, asset_alias_1, asset): assert asset_alias_1.evaluate({asset.uri: True}) is False - assert asset_alias_1._resolve_assets.mock_calls == [mock.call()] + assert asset_alias_1._resolve_assets.mock_calls == [mock.call(None)] def test_evalute_resolved(self, resolved_asset_alias_2, asset): assert resolved_asset_alias_2.evaluate({asset.uri: True}) is True - assert resolved_asset_alias_2._resolve_assets.mock_calls == [mock.call()] + assert resolved_asset_alias_2._resolve_assets.mock_calls == [mock.call(None)] class TestAssetSubclasses: