From 796b0819c42df3487e0b7a6cfd240a4c995d9496 Mon Sep 17 00:00:00 2001 From: Sameer Mesiah Date: Sat, 27 Dec 2025 17:40:17 +0000 Subject: [PATCH] Added validation for consumed_asset_event for DagRunContext. Unit tests included. --- .../airflow/callbacks/callback_requests.py | 69 ++++++++++++++++++- .../unit/callbacks/test_callback_requests.py | 62 +++++++++++++++++ 2 files changed, 129 insertions(+), 2 deletions(-) diff --git a/airflow-core/src/airflow/callbacks/callback_requests.py b/airflow-core/src/airflow/callbacks/callback_requests.py index d2bdf0968bc20..ce48438f0e70a 100644 --- a/airflow-core/src/airflow/callbacks/callback_requests.py +++ b/airflow-core/src/airflow/callbacks/callback_requests.py @@ -16,9 +16,15 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Annotated, Literal +from collections.abc import Mapping +from typing import TYPE_CHECKING, Annotated, Any, Literal, cast -from pydantic import BaseModel, Field +import structlog +from pydantic import BaseModel, Field, model_validator +from sqlalchemy import inspect as sa_inspect +from sqlalchemy.exc import NoInspectionAvailable +from sqlalchemy.orm.attributes import set_committed_value +from sqlalchemy.orm.exc import DetachedInstanceError from airflow.api_fastapi.execution_api.datamodels import taskinstance as ti_datamodel # noqa: TC001 from airflow.utils.state import TaskInstanceState @@ -26,6 +32,8 @@ if TYPE_CHECKING: from airflow.typing_compat import Self +log = structlog.get_logger(logger_name=__name__) + class BaseCallbackRequest(BaseModel): """ @@ -95,6 +103,63 @@ class DagRunContext(BaseModel): dag_run: ti_datamodel.DagRun | None = None last_ti: ti_datamodel.TaskInstance | None = None + @model_validator(mode="before") + @classmethod + def _sanitize_consumed_asset_events(cls, values: Mapping[str, Any]) -> Mapping[str, Any]: + if (dag_run := values.get("dag_run")) is None: + return values + + # DagRunContext may receive non-ORM dag_run objects (e.g. datamodels). + # Only apply this validator to ORM-mapped instances. + try: + sa_inspect(dag_run) + except NoInspectionAvailable: + return values + + # Relationship access may raise DetachedInstanceError; on that path, reload DagRun + # from the DB to avoid crashing the scheduler. + try: + events = dag_run.consumed_asset_events + set_committed_value( + dag_run, + "consumed_asset_events", + list(events) if events is not None else [], + ) + except DetachedInstanceError: + log.warning( + "DagRunContext encountered DetachedInstanceError while accessing " + "consumed_asset_events; reloading DagRun from DB." + ) + from sqlalchemy import select + from sqlalchemy.orm import selectinload + + from airflow.models.asset import AssetEvent + from airflow.models.dagrun import DagRun + from airflow.utils.session import create_session + + # Defensive guardrail: reload DagRun with eager-loaded relationships on + # DetachedInstanceError to recover state without adding DB I/O to the hot path. + with create_session() as session: + dag_run_reloaded = session.scalar( + select(DagRun) + .where(DagRun.id == dag_run.id) + .options( + selectinload(DagRun.consumed_asset_events).selectinload(AssetEvent.asset), + selectinload(DagRun.consumed_asset_events).selectinload(AssetEvent.source_aliases), + ) + ) + + # DagRun exists; reload is expected to succeed. + dag_run_reloaded = cast("DagRun", dag_run_reloaded) + reloaded_events = dag_run_reloaded.consumed_asset_events + + # Install DB-backed relationship state on the detached instance. + set_committed_value( + dag_run, "consumed_asset_events", list(reloaded_events) if reloaded_events is not None else [] + ) + + return values + class DagCallbackRequest(BaseCallbackRequest): """A Class with information about the success/failure DAG callback to be executed.""" diff --git a/airflow-core/tests/unit/callbacks/test_callback_requests.py b/airflow-core/tests/unit/callbacks/test_callback_requests.py index 6c646ee6c9954..434223e37475f 100644 --- a/airflow-core/tests/unit/callbacks/test_callback_requests.py +++ b/airflow-core/tests/unit/callbacks/test_callback_requests.py @@ -34,6 +34,7 @@ EmailRequest, TaskCallbackRequest, ) +from airflow.models import DagRun from airflow.models.taskinstance import TaskInstance from airflow.serialization.definitions.baseoperator import SerializedBaseOperator from airflow.utils.state import State, TaskInstanceState @@ -197,6 +198,67 @@ def test_dagrun_context_serialization(self): assert deserialized.dag_run.dag_id == context.dag_run.dag_id assert deserialized.last_ti.task_id == context.last_ti.task_id + def test_dagrun_context_detached_consumed_asset_events(self, session): + """ + DagRunContext should not fail if a detached DagRun raises + DetachedInstanceError when accessing consumed_asset_events. + """ + # Create a real ORM DagRun. + current_time = timezone.utcnow() + dag_run = DagRun( + dag_id="test_dag", + run_id="test_run_detached", + logical_date=current_time, + state="running", + run_type="manual", + ) + + # Forcefully detached it to replicate failure mode. + session.add(dag_run) + session.commit() + session.expunge(dag_run) + + # Validation for consumed_asset_events occurs on creation of DagRunContext. + context = DagRunContext(dag_run=dag_run, last_ti=None) + + # Access should be safe and not raise DetachedInstanceError. + events = context.dag_run.consumed_asset_events + + # Relationship should be normalized to a safe iterable. + assert events is not None + assert isinstance(events, list) + + def test_dagrun_context_attached_consumed_asset_events(self, session): + """ + DagRunContext should safely normalize consumed_asset_events + when the DagRun is attached to a session. + """ + current_time = timezone.utcnow() + dag_run = DagRun( + dag_id="test_dag", + run_id="test_run_attached", + logical_date=current_time, + state="running", + run_type="manual", + ) + + # Do not detach + session.add(dag_run) + session.flush() + + # Construct context while DagRun is still attached. + context = DagRunContext( + dag_run=dag_run, + last_ti=None, + ) + + # Access should be safe and not raise DetachedInstanceError. + events = context.dag_run.consumed_asset_events + + # Relationship should be normalized to a safe iterable. + assert events is not None + assert isinstance(events, list) + class TestDagCallbackRequestWithContext: def test_dag_callback_request_with_context_from_server(self):