diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index ddab7dd971f10..7be7150a74714 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -93,7 +93,7 @@
     from types import FrameType

     from pendulum.datetime import DateTime
-    from sqlalchemy.orm import Query, Session
+    from sqlalchemy.orm import Load, Query, Session

     from airflow._shared.logging.types import Logger
     from airflow.executors.base_executor import BaseExecutor
@@ -110,6 +110,31 @@
 """:meta private:"""


+def _eager_load_dag_run_for_validation() -> tuple[Load, Load]:
+    """
+    Eager-load DagRun relations required for execution API datamodel validation.
+
+    When building TaskCallbackRequest with DRDataModel.model_validate(ti.dag_run),
+    the consumed_asset_events collection and nested asset/source_aliases must be
+    preloaded to avoid DetachedInstanceError after the session closes.
+
+    Returns a tuple of two load options:
+    - Asset loader: TI.dag_run → consumed_asset_events → asset
+    - Alias loader: TI.dag_run → consumed_asset_events → source_aliases
+
+    Example usage::
+
+        asset_loader, alias_loader = _eager_load_dag_run_for_validation()
+        query = select(TI).options(asset_loader).options(alias_loader)
+    """
+    # Traverse TI → dag_run → consumed_asset_events once, then branch to asset/aliases
+    base = joinedload(TI.dag_run).selectinload(DagRun.consumed_asset_events)
+    return (
+        base.selectinload(AssetEvent.asset),
+        base.selectinload(AssetEvent.source_aliases),
+    )
+
+
 def _get_current_dag(dag_id: str, session: Session) -> SerializedDAG | None:
     serdag = SerializedDagModel.get(dag_id=dag_id, session=session)  # grabs the latest version
     if not serdag:
@@ -806,11 +831,12 @@ def process_executor_events(
         # Check state of finished tasks
         filter_for_tis = TI.filter_for_tis(tis_with_right_state)
+        asset_loader, _ = _eager_load_dag_run_for_validation()
         query = (
             select(TI)
             .where(filter_for_tis)
             .options(selectinload(TI.dag_model))
-            .options(joinedload(TI.dag_run).selectinload(DagRun.consumed_asset_events))
+            .options(asset_loader)
             .options(joinedload(TI.dag_run).selectinload(DagRun.created_dag_version))
             .options(joinedload(TI.dag_version))
         )
@@ -2375,10 +2401,12 @@ def _find_and_purge_task_instances_without_heartbeats(self) -> None:
     def _find_task_instances_without_heartbeats(self, *, session: Session) -> list[TI]:
         self.log.debug("Finding 'running' jobs without a recent heartbeat")
         limit_dttm = timezone.utcnow() - timedelta(seconds=self._task_instance_heartbeat_timeout_secs)
+        asset_loader, alias_loader = _eager_load_dag_run_for_validation()
         task_instances_without_heartbeats = session.scalars(
             select(TI)
             .options(selectinload(TI.dag_model))
-            .options(selectinload(TI.dag_run).selectinload(DagRun.consumed_asset_events))
+            .options(asset_loader)
+            .options(alias_loader)
             .options(selectinload(TI.dag_version))
             .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
             .join(DM, TI.dag_id == DM.dag_id)
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 7c03dd4fa1226..9fd022b744c4c 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -608,6 +608,67 @@ def test_process_executor_events_ti_requeued(self, mock_stats_incr, mock_task_ca
         scheduler_job.executor.callback_sink.send.assert_not_called()
         mock_stats_incr.assert_not_called()

+
+    @pytest.mark.usefixtures("testing_dag_bundle")
+    @mock.patch("airflow.jobs.scheduler_job_runner.Stats.incr")
+    def test_process_executor_events_with_asset_events(self, mock_stats_incr, session, dag_maker):
+        """
+        Test that _process_executor_events handles asset events without DetachedInstanceError.
+
+        Regression test for scheduler crashes when task callbacks are built with
+        consumed_asset_events that weren't eager-loaded.
+        """
+        asset1 = Asset(uri="test://asset1", name="test_asset_executor", group="test_group")
+        asset_model = AssetModel(name=asset1.name, uri=asset1.uri, group=asset1.group)
+        session.add(asset_model)
+        session.flush()
+
+        with dag_maker(dag_id="test_executor_events_with_assets", schedule=[asset1], fileloc="/test_path1/"):
+            EmptyOperator(task_id="dummy_task", on_failure_callback=lambda ctx: None)
+
+        dag = dag_maker.dag
+        sync_dag_to_db(dag)
+        DagVersion.get_latest_version(dag.dag_id)
+
+        dr = dag_maker.create_dagrun()
+
+        # Create asset event and attach to dag run
+        asset_event = AssetEvent(
+            asset_id=asset_model.id,
+            source_task_id="upstream_task",
+            source_dag_id="upstream_dag",
+            source_run_id="upstream_run",
+            source_map_index=-1,
+        )
+        session.add(asset_event)
+        session.flush()
+        dr.consumed_asset_events.append(asset_event)
+        session.add(dr)
+        session.flush()
+
+        executor = MockExecutor(do_update=False)
+        scheduler_job = Job(executor=executor)
+        self.job_runner = SchedulerJobRunner(scheduler_job)
+
+        ti1 = dr.get_task_instance("dummy_task")
+        ti1.state = State.QUEUED
+        session.merge(ti1)
+        session.commit()
+
+        executor.event_buffer[ti1.key] = State.FAILED, None
+
+        # This should not raise DetachedInstanceError
+        self.job_runner._process_executor_events(executor=executor, session=session)
+
+        ti1.refresh_from_db(session=session)
+        assert ti1.state == State.FAILED
+
+        # Verify callback was created with asset event data
+        scheduler_job.executor.callback_sink.send.assert_called_once()
+        callback_request = scheduler_job.executor.callback_sink.send.call_args.args[0]
+        assert callback_request.context_from_server is not None
+        assert len(callback_request.context_from_server.dag_run.consumed_asset_events) == 1
+        assert callback_request.context_from_server.dag_run.consumed_asset_events[0].asset.uri == asset1.uri
+
     def test_execute_task_instances_is_paused_wont_execute(self, session, dag_maker):
         dag_id = "SchedulerJobTest.test_execute_task_instances_is_paused_wont_execute"
         task_id_1 = "dummy_task"
@@ -628,6 +689,97 @@ def test_execute_task_instances_is_paused_wont_execute(self, session, dag_maker)
         assert ti1.state == State.SCHEDULED
         session.rollback()

+    @pytest.mark.usefixtures("testing_dag_bundle")
+    def test_find_and_purge_task_instances_without_heartbeats_with_asset_events(
+        self, session, dag_maker, create_dagrun
+    ):
+        """
+        Test that heartbeat purge succeeds when DagRun has consumed_asset_events.
+
+        Regression test for DetachedInstanceError when building TaskCallbackRequest
+        with asset event data after session expunge.
+ """ + asset1 = Asset(uri="test://asset1", name="test_asset", group="test_group") + asset_model = AssetModel(name=asset1.name, uri=asset1.uri, group=asset1.group) + session.add(asset_model) + session.flush() + + with dag_maker(dag_id="test_heartbeat_with_assets", schedule=[asset1]): + EmptyOperator(task_id="dummy_task") + + dag = dag_maker.dag + scheduler_dag = sync_dag_to_db(dag) + dag_v = DagVersion.get_latest_version(dag.dag_id) + + data_interval = infer_automated_data_interval(scheduler_dag.timetable, DEFAULT_LOGICAL_DATE) + dag_run = create_dagrun( + scheduler_dag, + logical_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + data_interval=data_interval, + ) + + # Create asset alias and event with full relationships + asset_alias = AssetAliasModel(name="test_alias", group="test_group") + session.add(asset_alias) + session.flush() + + asset_event = AssetEvent( + asset_id=asset_model.id, + source_task_id="upstream_task", + source_dag_id="upstream_dag", + source_run_id="upstream_run", + source_map_index=-1, + ) + session.add(asset_event) + session.flush() + + # Attach alias to event and event to dag run + asset_event.source_aliases.append(asset_alias) + dag_run.consumed_asset_events.append(asset_event) + session.add_all([asset_event, dag_run]) + session.flush() + + executor = MockExecutor() + scheduler_job = Job(executor=executor) + with mock.patch("airflow.executors.executor_loader.ExecutorLoader.load_executor") as loader_mock: + loader_mock.return_value = executor + self.job_runner = SchedulerJobRunner(job=scheduler_job) + + ti = dag_run.get_task_instance("dummy_task") + assert ti is not None # sanity check: dag_maker.create_dagrun created the TI + + ti.state = State.RUNNING + ti.last_heartbeat_at = timezone.utcnow() - timedelta(minutes=6) + ti.start_date = timezone.utcnow() - timedelta(minutes=10) + ti.queued_by_job_id = scheduler_job.id + ti.dag_version = dag_v + session.merge(ti) + session.flush() + + executor.running.add(ti.key) + + tis_without_heartbeats = self.job_runner._find_task_instances_without_heartbeats(session=session) + assert len(tis_without_heartbeats) == 1 + ti_from_query = tis_without_heartbeats[0] + ti_key = ti_from_query.key + + # Detach all ORM objects to mirror scheduler behaviour after session closes + session.expunge_all() + + # This should not raise DetachedInstanceError now that eager loads are in place + self.job_runner._purge_task_instances_without_heartbeats(tis_without_heartbeats, session=session) + assert ti_key not in executor.running + + executor.callback_sink.send.assert_called_once() + callback_request = executor.callback_sink.send.call_args.args[0] + assert callback_request.context_from_server is not None + assert len(callback_request.context_from_server.dag_run.consumed_asset_events) == 1 + consumed_event = callback_request.context_from_server.dag_run.consumed_asset_events[0] + assert consumed_event.asset.uri == asset1.uri + assert len(consumed_event.source_aliases) == 1 + assert consumed_event.source_aliases[0].name == "test_alias" + # @pytest.mark.usefixtures("mock_executor") def test_execute_task_instances_backfill_tasks_will_execute(self, dag_maker): """