Merged
34 changes: 31 additions & 3 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -93,7 +93,7 @@
from types import FrameType

from pendulum.datetime import DateTime
- from sqlalchemy.orm import Query, Session
+ from sqlalchemy.orm import Load, Query, Session

from airflow._shared.logging.types import Logger
from airflow.executors.base_executor import BaseExecutor
@@ -110,6 +110,31 @@
""":meta private:"""


def _eager_load_dag_run_for_validation() -> tuple[Load, Load]:
"""
Eager-load DagRun relations required for execution API datamodel validation.

When building TaskCallbackRequest with DRDataModel.model_validate(ti.dag_run),
the consumed_asset_events collection and nested asset/source_aliases must be
preloaded to avoid DetachedInstanceError after the session closes.

Returns a tuple of two load options:
- Asset loader: TI.dag_run → consumed_asset_events → asset
- Alias loader: TI.dag_run → consumed_asset_events → source_aliases

Example usage::

asset_loader, alias_loader = _eager_load_dag_run_for_validation()
query = select(TI).options(asset_loader).options(alias_loader)
"""
# Traverse TI → dag_run → consumed_asset_events once, then branch to asset/aliases
base = joinedload(TI.dag_run).selectinload(DagRun.consumed_asset_events)
return (
base.selectinload(AssetEvent.asset),
base.selectinload(AssetEvent.source_aliases),
)


def _get_current_dag(dag_id: str, session: Session) -> SerializedDAG | None:
serdag = SerializedDagModel.get(dag_id=dag_id, session=session) # grabs the latest version
if not serdag:
@@ -806,11 +831,12 @@ def process_executor_events(

# Check state of finished tasks
filter_for_tis = TI.filter_for_tis(tis_with_right_state)
+ asset_loader, _ = _eager_load_dag_run_for_validation()
query = (
select(TI)
.where(filter_for_tis)
.options(selectinload(TI.dag_model))
- .options(joinedload(TI.dag_run).selectinload(DagRun.consumed_asset_events))
+ .options(asset_loader)
.options(joinedload(TI.dag_run).selectinload(DagRun.created_dag_version))
.options(joinedload(TI.dag_version))
)
@@ -2375,10 +2401,12 @@ def _find_and_purge_task_instances_without_heartbeats(self) -> None:
def _find_task_instances_without_heartbeats(self, *, session: Session) -> list[TI]:
self.log.debug("Finding 'running' jobs without a recent heartbeat")
limit_dttm = timezone.utcnow() - timedelta(seconds=self._task_instance_heartbeat_timeout_secs)
+ asset_loader, alias_loader = _eager_load_dag_run_for_validation()
task_instances_without_heartbeats = session.scalars(
select(TI)
.options(selectinload(TI.dag_model))
- .options(selectinload(TI.dag_run).selectinload(DagRun.consumed_asset_events))
+ .options(asset_loader)
+ .options(alias_loader)
.options(selectinload(TI.dag_version))
.with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
.join(DM, TI.dag_id == DM.dag_id)
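For context, a minimal sketch of the failure mode these loader options prevent; it reuses names from the diff above (session, filter_for_tis, TI, select, DRDataModel) and is illustrative rather than the scheduler's exact control flow:

# Illustrative sketch only; assumes the names used in the module above.
asset_loader, alias_loader = _eager_load_dag_run_for_validation()
tis = session.scalars(
    select(TI)
    .where(filter_for_tis)
    .options(asset_loader)
    .options(alias_loader)
).all()
session.expunge_all()  # the scheduler detaches ORM objects before building callbacks
for ti in tis:
    # model_validate walks dag_run.consumed_asset_events[*].asset and
    # .source_aliases; without the eager loads, each attribute access would
    # lazy-load against a detached instance and raise DetachedInstanceError.
    DRDataModel.model_validate(ti.dag_run)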
152 changes: 152 additions & 0 deletions airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -608,6 +608,67 @@ def test_process_executor_events_ti_requeued(self, mock_stats_incr, mock_task_ca
scheduler_job.executor.callback_sink.send.assert_not_called()
mock_stats_incr.assert_not_called()

@pytest.mark.usefixtures("testing_dag_bundle")
@mock.patch("airflow.jobs.scheduler_job_runner.Stats.incr")
def test_process_executor_events_with_asset_events(self, mock_stats_incr, session, dag_maker):
"""
Test that _process_executor_events handles asset events without DetachedInstanceError.

Regression test for scheduler crashes when task callbacks are built with
consumed_asset_events that weren't eager-loaded.
"""
asset1 = Asset(uri="test://asset1", name="test_asset_executor", group="test_group")
asset_model = AssetModel(name=asset1.name, uri=asset1.uri, group=asset1.group)
session.add(asset_model)
session.flush()

with dag_maker(dag_id="test_executor_events_with_assets", schedule=[asset1], fileloc="/test_path1/"):
EmptyOperator(task_id="dummy_task", on_failure_callback=lambda ctx: None)

dag = dag_maker.dag
sync_dag_to_db(dag)
DagVersion.get_latest_version(dag.dag_id)

dr = dag_maker.create_dagrun()

# Create asset event and attach to dag run
asset_event = AssetEvent(
asset_id=asset_model.id,
source_task_id="upstream_task",
source_dag_id="upstream_dag",
source_run_id="upstream_run",
source_map_index=-1,
)
session.add(asset_event)
session.flush()
dr.consumed_asset_events.append(asset_event)
session.add(dr)
session.flush()

executor = MockExecutor(do_update=False)
scheduler_job = Job(executor=executor)
self.job_runner = SchedulerJobRunner(scheduler_job)

ti1 = dr.get_task_instance("dummy_task")
ti1.state = State.QUEUED
session.merge(ti1)
session.commit()

executor.event_buffer[ti1.key] = State.FAILED, None

# This should not raise DetachedInstanceError
self.job_runner._process_executor_events(executor=executor, session=session)

ti1.refresh_from_db(session=session)
assert ti1.state == State.FAILED

# Verify callback was created with asset event data
scheduler_job.executor.callback_sink.send.assert_called_once()
callback_request = scheduler_job.executor.callback_sink.send.call_args.args[0]
assert callback_request.context_from_server is not None
assert len(callback_request.context_from_server.dag_run.consumed_asset_events) == 1
assert callback_request.context_from_server.dag_run.consumed_asset_events[0].asset.uri == asset1.uri

def test_execute_task_instances_is_paused_wont_execute(self, session, dag_maker):
dag_id = "SchedulerJobTest.test_execute_task_instances_is_paused_wont_execute"
task_id_1 = "dummy_task"
@@ -628,6 +689,97 @@ def test_execute_task_instances_is_paused_wont_execute(self, session, dag_maker)
assert ti1.state == State.SCHEDULED
session.rollback()

@pytest.mark.usefixtures("testing_dag_bundle")
def test_find_and_purge_task_instances_without_heartbeats_with_asset_events(
self, session, dag_maker, create_dagrun
):
"""
Test that heartbeat purge succeeds when DagRun has consumed_asset_events.

Regression test for DetachedInstanceError when building TaskCallbackRequest
with asset event data after session expunge.
"""
asset1 = Asset(uri="test://asset1", name="test_asset", group="test_group")
asset_model = AssetModel(name=asset1.name, uri=asset1.uri, group=asset1.group)
session.add(asset_model)
session.flush()

with dag_maker(dag_id="test_heartbeat_with_assets", schedule=[asset1]):
EmptyOperator(task_id="dummy_task")

dag = dag_maker.dag
scheduler_dag = sync_dag_to_db(dag)
dag_v = DagVersion.get_latest_version(dag.dag_id)

data_interval = infer_automated_data_interval(scheduler_dag.timetable, DEFAULT_LOGICAL_DATE)
dag_run = create_dagrun(
scheduler_dag,
logical_date=DEFAULT_DATE,
run_type=DagRunType.SCHEDULED,
data_interval=data_interval,
)

# Create asset alias and event with full relationships
asset_alias = AssetAliasModel(name="test_alias", group="test_group")
session.add(asset_alias)
session.flush()

asset_event = AssetEvent(
asset_id=asset_model.id,
source_task_id="upstream_task",
source_dag_id="upstream_dag",
source_run_id="upstream_run",
source_map_index=-1,
)
session.add(asset_event)
session.flush()

# Attach alias to event and event to dag run
asset_event.source_aliases.append(asset_alias)
dag_run.consumed_asset_events.append(asset_event)
session.add_all([asset_event, dag_run])
session.flush()

executor = MockExecutor()
scheduler_job = Job(executor=executor)
with mock.patch("airflow.executors.executor_loader.ExecutorLoader.load_executor") as loader_mock:
loader_mock.return_value = executor
self.job_runner = SchedulerJobRunner(job=scheduler_job)

ti = dag_run.get_task_instance("dummy_task")
assert ti is not None  # sanity check: the create_dagrun fixture created the TI

ti.state = State.RUNNING
ti.last_heartbeat_at = timezone.utcnow() - timedelta(minutes=6)
ti.start_date = timezone.utcnow() - timedelta(minutes=10)
ti.queued_by_job_id = scheduler_job.id
ti.dag_version = dag_v
session.merge(ti)
session.flush()

executor.running.add(ti.key)

tis_without_heartbeats = self.job_runner._find_task_instances_without_heartbeats(session=session)
assert len(tis_without_heartbeats) == 1
ti_from_query = tis_without_heartbeats[0]
ti_key = ti_from_query.key

# Detach all ORM objects to mirror scheduler behaviour after session closes
session.expunge_all()

# This should not raise DetachedInstanceError now that eager loads are in place
self.job_runner._purge_task_instances_without_heartbeats(tis_without_heartbeats, session=session)
assert ti_key not in executor.running

executor.callback_sink.send.assert_called_once()
callback_request = executor.callback_sink.send.call_args.args[0]
assert callback_request.context_from_server is not None
assert len(callback_request.context_from_server.dag_run.consumed_asset_events) == 1
consumed_event = callback_request.context_from_server.dag_run.consumed_asset_events[0]
assert consumed_event.asset.uri == asset1.uri
assert len(consumed_event.source_aliases) == 1
assert consumed_event.source_aliases[0].name == "test_alias"

# @pytest.mark.usefixtures("mock_executor")
def test_execute_task_instances_backfill_tasks_will_execute(self, dag_maker):
"""
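The expunge-then-touch pattern in the heartbeat test above generalizes to any eager-load regression test: load the rows with the loader options under test, detach everything from the session, then access the relationship. A minimal standalone repro, where Parent and Child are hypothetical stand-ins for DagRun and AssetEvent rather than Airflow models:

# Standalone repro of the DetachedInstanceError pattern guarded against above.
# Parent/Child are hypothetical stand-ins, not Airflow models.
from sqlalchemy import ForeignKey, create_engine, select
from sqlalchemy.orm import (
    DeclarativeBase, Mapped, Session, mapped_column, relationship, selectinload,
)

class Base(DeclarativeBase):
    pass

class Parent(Base):
    __tablename__ = "parent"
    id: Mapped[int] = mapped_column(primary_key=True)
    children: Mapped[list["Child"]] = relationship(back_populates="parent")

class Child(Base):
    __tablename__ = "child"
    id: Mapped[int] = mapped_column(primary_key=True)
    parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id"))
    parent: Mapped["Parent"] = relationship(back_populates="children")

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Parent(children=[Child()]))
    session.commit()

    # Eager-load the collection, then detach everything, as the scheduler does.
    parent = session.scalars(
        select(Parent).options(selectinload(Parent.children))
    ).one()
    session.expunge_all()
    # Safe only because selectinload populated the collection before the detach;
    # without it, this line raises DetachedInstanceError.
    assert len(parent.children) == 1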