diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 4557445151051..0763f59e3b180 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -93,7 +93,7 @@ from types import FrameType from pendulum.datetime import DateTime - from sqlalchemy.orm import Query, Session + from sqlalchemy.orm import Load, Query, Session from airflow._shared.logging.types import Logger from airflow.executors.base_executor import BaseExecutor @@ -110,6 +110,31 @@ """:meta private:""" +def _eager_load_dag_run_for_validation() -> tuple[Load, Load]: + """ + Eager-load DagRun relations required for execution API datamodel validation. + + When building TaskCallbackRequest with DRDataModel.model_validate(ti.dag_run), + the consumed_asset_events collection and nested asset/source_aliases must be + preloaded to avoid DetachedInstanceError after the session closes. + + Returns a tuple of two load options: + - Asset loader: TI.dag_run → consumed_asset_events → asset + - Alias loader: TI.dag_run → consumed_asset_events → source_aliases + + Example usage:: + + asset_loader, alias_loader = _eager_load_dag_run_for_validation() + query = select(TI).options(asset_loader).options(alias_loader) + """ + # Traverse TI → dag_run → consumed_asset_events once, then branch to asset/aliases + base = joinedload(TI.dag_run).selectinload(DagRun.consumed_asset_events) + return ( + base.selectinload(AssetEvent.asset), + base.selectinload(AssetEvent.source_aliases), + ) + + def _get_current_dag(dag_id: str, session: Session) -> SerializedDAG | None: serdag = SerializedDagModel.get(dag_id=dag_id, session=session) # grabs the latest version if not serdag: @@ -806,11 +831,12 @@ def process_executor_events( # Check state of finished tasks filter_for_tis = TI.filter_for_tis(tis_with_right_state) + asset_loader, _ = _eager_load_dag_run_for_validation() query = ( select(TI) .where(filter_for_tis) .options(selectinload(TI.dag_model)) - .options(joinedload(TI.dag_run).selectinload(DagRun.consumed_asset_events)) + .options(asset_loader) .options(joinedload(TI.dag_run).selectinload(DagRun.created_dag_version)) .options(joinedload(TI.dag_version)) ) @@ -2375,10 +2401,12 @@ def _find_and_purge_task_instances_without_heartbeats(self) -> None: def _find_task_instances_without_heartbeats(self, *, session: Session) -> list[TI]: self.log.debug("Finding 'running' jobs without a recent heartbeat") limit_dttm = timezone.utcnow() - timedelta(seconds=self._task_instance_heartbeat_timeout_secs) + asset_loader, alias_loader = _eager_load_dag_run_for_validation() task_instances_without_heartbeats = session.scalars( select(TI) .options(selectinload(TI.dag_model)) - .options(selectinload(TI.dag_run).selectinload(DagRun.consumed_asset_events)) + .options(asset_loader) + .options(alias_loader) .options(selectinload(TI.dag_version)) .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql") .join(DM, TI.dag_id == DM.dag_id) diff --git a/airflow-core/src/airflow/ui/public/i18n/locales/de/common.json b/airflow-core/src/airflow/ui/public/i18n/locales/de/common.json index fe14e8e0d6713..e4d46a928f37e 100644 --- a/airflow-core/src/airflow/ui/public/i18n/locales/de/common.json +++ b/airflow-core/src/airflow/ui/public/i18n/locales/de/common.json @@ -78,6 +78,11 @@ "githubRepo": "GitHub Ablage", "restApiReference": "REST API Referenz" }, + "download": { + "download": "Herunterladen", + "hotkey": "d", + "tooltip": "Drücken Sie 
{{hotkey}}, um Protokolle herunterzuladen" + }, "duration": "Laufzeit", "endDate": "Enddatum", "error": { diff --git a/airflow-core/src/airflow/ui/public/i18n/locales/de/components.json b/airflow-core/src/airflow/ui/public/i18n/locales/de/components.json index 3e4166ba90804..bad98315e82f8 100644 --- a/airflow-core/src/airflow/ui/public/i18n/locales/de/components.json +++ b/airflow-core/src/airflow/ui/public/i18n/locales/de/components.json @@ -101,8 +101,8 @@ "location": "Zeile {{line}} in {{name}}" }, "reparseDag": "Dag neu parsen", - "sortedAscending": "aufsteigend sortier", - "sortedDescending": "absteigend sortier", + "sortedAscending": "aufsteigend sortiert", + "sortedDescending": "absteigend sortiert", "sortedUnsorted": "unsortiert", "taskTries": "Versuch des Tasks", "toggleCardView": "Kachelansicht anzeigen", diff --git a/airflow-core/src/airflow/ui/public/i18n/locales/de/dags.json b/airflow-core/src/airflow/ui/public/i18n/locales/de/dags.json index a81349c6e17a6..8eea58538a456 100644 --- a/airflow-core/src/airflow/ui/public/i18n/locales/de/dags.json +++ b/airflow-core/src/airflow/ui/public/i18n/locales/de/dags.json @@ -20,8 +20,7 @@ "all": "Alle", "paused": "Pausiert" }, - "runIdPatternFilter": "Dag Läufe suchen", - "triggeringUserNameFilter": "Suche Läufe ausgelöst von..." + "runIdPatternFilter": "Dag Läufe suchen" }, "ownerLink": "Besitzer Verlinkungen zu {{owner}}", "runAndTaskActions": { diff --git a/airflow-core/src/airflow/ui/public/i18n/locales/de/hitl.json b/airflow-core/src/airflow/ui/public/i18n/locales/de/hitl.json index 56162417601cf..48212bcd9d40b 100644 --- a/airflow-core/src/airflow/ui/public/i18n/locales/de/hitl.json +++ b/airflow-core/src/airflow/ui/public/i18n/locales/de/hitl.json @@ -1,5 +1,8 @@ { "filters": { + "body": "Nachricht", + "createdAtFrom": "Erstellt von ", + "createdAtTo": "Erstellt bis ", "response": { "all": "Alle", "pending": "Ausstehend", @@ -12,11 +15,13 @@ "requiredActionCount_other": "{{count}} offene Interaktionen", "requiredActionState": "Status der Interaktion", "response": { + "created": "Antwort erstellt um ", "error": "Senden der Antwort fehlgeschlagen", "optionsDescription": "Wählen Sie Ihre Optionen für diesen Task", "optionsLabel": "Optionen", "received": "Antwort empfangen um ", "respond": "Antworten", + "responded_by_user_name": "Beantwortet von (Benutzername)", "success": "{{taskId}} Interaktion erfolgreich", "title": "Erforderliche Interaktion - {{taskId}}" }, diff --git a/airflow-core/src/airflow/ui/public/i18n/locales/pl/common.json b/airflow-core/src/airflow/ui/public/i18n/locales/pl/common.json index 04ca807223bc1..eaafe9c4b5a6b 100644 --- a/airflow-core/src/airflow/ui/public/i18n/locales/pl/common.json +++ b/airflow-core/src/airflow/ui/public/i18n/locales/pl/common.json @@ -113,6 +113,8 @@ }, "filter": "Filtr", "filters": { + "durationFrom": "Czas trwania od", + "durationTo": "Czas trwania do", "logicalDateFrom": "Data logiczna od", "logicalDateTo": "Data logiczna do", "runAfterFrom": "Uruchom po (od)", diff --git a/airflow-core/src/airflow/ui/public/i18n/locales/pl/dag.json b/airflow-core/src/airflow/ui/public/i18n/locales/pl/dag.json index 27276f595f68d..3a55fced65990 100644 --- a/airflow-core/src/airflow/ui/public/i18n/locales/pl/dag.json +++ b/airflow-core/src/airflow/ui/public/i18n/locales/pl/dag.json @@ -10,6 +10,7 @@ "hourly": "Godzinowo", "legend": { "less": "Mniej", + "mixed": "Mieszane", "more": "Więcej" }, "navigation": { @@ -19,6 +20,7 @@ "previousYear": "Poprzedni rok" }, "noData": "Brak danych", + "noFailedRuns": 
"Brak nieudanych wykonań", "noRuns": "Brak wykonań", "totalRuns": "Łączna liczba wykonań", "week": "Tydzień {{weekNumber}}", diff --git a/airflow-core/src/airflow/ui/public/i18n/locales/pl/dags.json b/airflow-core/src/airflow/ui/public/i18n/locales/pl/dags.json index 61b1f1fe712c4..76dfb64f269f4 100644 --- a/airflow-core/src/airflow/ui/public/i18n/locales/pl/dags.json +++ b/airflow-core/src/airflow/ui/public/i18n/locales/pl/dags.json @@ -20,7 +20,8 @@ "all": "Wszystkie", "paused": "Wstrzymane" }, - "runIdPatternFilter": "Szukaj Wykonań Dagów" + "runIdPatternFilter": "Szukaj Wykonań Dagów", + "triggeringUserNameFilter": "Szukaj według użytkownika wyzwalającego" }, "ownerLink": "Link do właściciela {{owner}}", "runAndTaskActions": { diff --git a/airflow-core/src/airflow/ui/src/layouts/Details/PanelButtons.tsx b/airflow-core/src/airflow/ui/src/layouts/Details/PanelButtons.tsx index b43895ced2c91..6a766fb1f0ddd 100644 --- a/airflow-core/src/airflow/ui/src/layouts/Details/PanelButtons.tsx +++ b/airflow-core/src/airflow/ui/src/layouts/Details/PanelButtons.tsx @@ -208,11 +208,13 @@ export const PanelButtons = ({ ); return ( - - + + { setDagView("grid"); @@ -221,12 +223,13 @@ export const PanelButtons = ({ } }} title={translate("dag:panel.buttons.showGridShortcut")} - variant={dagView === "grid" ? "solid" : "outline"} > { setDagView("graph"); @@ -235,17 +238,16 @@ export const PanelButtons = ({ } }} title={translate("dag:panel.buttons.showGraphShortcut")} - variant={dagView === "graph" ? "solid" : "outline"} > - + {/* eslint-disable-next-line jsx-a11y/no-autofocus */} - diff --git a/airflow-core/src/airflow/ui/src/pages/DagsList/DagsList.tsx b/airflow-core/src/airflow/ui/src/pages/DagsList/DagsList.tsx index ef3153876c37f..f0ea76cf87cef 100644 --- a/airflow-core/src/airflow/ui/src/pages/DagsList/DagsList.tsx +++ b/airflow-core/src/airflow/ui/src/pages/DagsList/DagsList.tsx @@ -164,6 +164,7 @@ const createColumns = ( header: "", }, { + accessorKey: "favourite", cell: ({ row: { original } }) => ( ), diff --git a/airflow-core/src/airflow/ui/src/pages/TaskInstance/ExtraLinks.tsx b/airflow-core/src/airflow/ui/src/pages/TaskInstance/ExtraLinks.tsx index f89764e0b8202..3559d48065bcc 100644 --- a/airflow-core/src/airflow/ui/src/pages/TaskInstance/ExtraLinks.tsx +++ b/airflow-core/src/airflow/ui/src/pages/TaskInstance/ExtraLinks.tsx @@ -23,7 +23,7 @@ import { useParams } from "react-router-dom"; import { useTaskInstanceServiceGetExtraLinks } from "openapi/queries"; export const ExtraLinks = () => { - const { t: translate } = useTranslation(); + const { t: translate } = useTranslation("dag"); const { dagId = "", mapIndex = "-1", runId = "", taskId = "" } = useParams(); const { data } = useTaskInstanceServiceGetExtraLinks({ @@ -35,7 +35,7 @@ export const ExtraLinks = () => { return data && Object.keys(data.extra_links).length > 0 ? ( - {translate("dag.extraLinks")} + {translate("extraLinks")} {Object.entries(data.extra_links).map(([key, value], _) => value === null ? 
undefined : (
diff --git a/airflow-core/src/airflow/ui/src/queries/useToggleFavoriteDag.ts b/airflow-core/src/airflow/ui/src/queries/useToggleFavoriteDag.ts
index d3cfaa7cf0c2d..6093177a64d59 100644
--- a/airflow-core/src/airflow/ui/src/queries/useToggleFavoriteDag.ts
+++ b/airflow-core/src/airflow/ui/src/queries/useToggleFavoriteDag.ts
@@ -35,10 +35,13 @@ export const useToggleFavoriteDag = (dagId: string) => {
       queryKey: [useDagServiceGetDagsUiKey, UseDagServiceGetDagDetailsKeyFn({ dagId }, [{ dagId }])],
     });

-    // Invalidate the specific DAG details query for this DAG
-    await queryClient.invalidateQueries({
-      queryKey: UseDagServiceGetDagDetailsKeyFn({ dagId }, [{ dagId }]),
-    });
+    const queryKeys = [
+      // Invalidate the specific DAG details query for this DAG and the DAGs list query.
+      UseDagServiceGetDagDetailsKeyFn({ dagId }, [{ dagId }]),
+      [useDagServiceGetDagsUiKey],
+    ];
+
+    await Promise.all(queryKeys.map((key) => queryClient.invalidateQueries({ queryKey: key })));
   }, [queryClient, dagId]);

   const favoriteMutation = useDagServiceFavoriteDag({
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index cfebf7b733f9b..49e82fd6587e5 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -608,6 +608,67 @@ def test_process_executor_events_ti_requeued(self, mock_stats_incr, mock_task_ca
         scheduler_job.executor.callback_sink.send.assert_not_called()
         mock_stats_incr.assert_not_called()

+    @pytest.mark.usefixtures("testing_dag_bundle")
+    @mock.patch("airflow.jobs.scheduler_job_runner.Stats.incr")
+    def test_process_executor_events_with_asset_events(self, mock_stats_incr, session, dag_maker):
+        """
+        Test that _process_executor_events handles asset events without DetachedInstanceError.
+
+        Regression test for scheduler crashes when task callbacks are built with
+        consumed_asset_events that weren't eager-loaded.
+ """ + asset1 = Asset(uri="test://asset1", name="test_asset_executor", group="test_group") + asset_model = AssetModel(name=asset1.name, uri=asset1.uri, group=asset1.group) + session.add(asset_model) + session.flush() + + with dag_maker(dag_id="test_executor_events_with_assets", schedule=[asset1], fileloc="/test_path1/"): + EmptyOperator(task_id="dummy_task", on_failure_callback=lambda ctx: None) + + dag = dag_maker.dag + sync_dag_to_db(dag) + DagVersion.get_latest_version(dag.dag_id) + + dr = dag_maker.create_dagrun() + + # Create asset event and attach to dag run + asset_event = AssetEvent( + asset_id=asset_model.id, + source_task_id="upstream_task", + source_dag_id="upstream_dag", + source_run_id="upstream_run", + source_map_index=-1, + ) + session.add(asset_event) + session.flush() + dr.consumed_asset_events.append(asset_event) + session.add(dr) + session.flush() + + executor = MockExecutor(do_update=False) + scheduler_job = Job(executor=executor) + self.job_runner = SchedulerJobRunner(scheduler_job) + + ti1 = dr.get_task_instance("dummy_task") + ti1.state = State.QUEUED + session.merge(ti1) + session.commit() + + executor.event_buffer[ti1.key] = State.FAILED, None + + # This should not raise DetachedInstanceError + self.job_runner._process_executor_events(executor=executor, session=session) + + ti1.refresh_from_db(session=session) + assert ti1.state == State.FAILED + + # Verify callback was created with asset event data + scheduler_job.executor.callback_sink.send.assert_called_once() + callback_request = scheduler_job.executor.callback_sink.send.call_args.args[0] + assert callback_request.context_from_server is not None + assert len(callback_request.context_from_server.dag_run.consumed_asset_events) == 1 + assert callback_request.context_from_server.dag_run.consumed_asset_events[0].asset.uri == asset1.uri + def test_execute_task_instances_is_paused_wont_execute(self, session, dag_maker): dag_id = "SchedulerJobTest.test_execute_task_instances_is_paused_wont_execute" task_id_1 = "dummy_task" @@ -628,6 +689,97 @@ def test_execute_task_instances_is_paused_wont_execute(self, session, dag_maker) assert ti1.state == State.SCHEDULED session.rollback() + @pytest.mark.usefixtures("testing_dag_bundle") + def test_find_and_purge_task_instances_without_heartbeats_with_asset_events( + self, session, dag_maker, create_dagrun + ): + """ + Test that heartbeat purge succeeds when DagRun has consumed_asset_events. + + Regression test for DetachedInstanceError when building TaskCallbackRequest + with asset event data after session expunge. 
+ """ + asset1 = Asset(uri="test://asset1", name="test_asset", group="test_group") + asset_model = AssetModel(name=asset1.name, uri=asset1.uri, group=asset1.group) + session.add(asset_model) + session.flush() + + with dag_maker(dag_id="test_heartbeat_with_assets", schedule=[asset1]): + EmptyOperator(task_id="dummy_task") + + dag = dag_maker.dag + scheduler_dag = sync_dag_to_db(dag) + dag_v = DagVersion.get_latest_version(dag.dag_id) + + data_interval = infer_automated_data_interval(scheduler_dag.timetable, DEFAULT_LOGICAL_DATE) + dag_run = create_dagrun( + scheduler_dag, + logical_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + data_interval=data_interval, + ) + + # Create asset alias and event with full relationships + asset_alias = AssetAliasModel(name="test_alias", group="test_group") + session.add(asset_alias) + session.flush() + + asset_event = AssetEvent( + asset_id=asset_model.id, + source_task_id="upstream_task", + source_dag_id="upstream_dag", + source_run_id="upstream_run", + source_map_index=-1, + ) + session.add(asset_event) + session.flush() + + # Attach alias to event and event to dag run + asset_event.source_aliases.append(asset_alias) + dag_run.consumed_asset_events.append(asset_event) + session.add_all([asset_event, dag_run]) + session.flush() + + executor = MockExecutor() + scheduler_job = Job(executor=executor) + with mock.patch("airflow.executors.executor_loader.ExecutorLoader.load_executor") as loader_mock: + loader_mock.return_value = executor + self.job_runner = SchedulerJobRunner(job=scheduler_job) + + ti = dag_run.get_task_instance("dummy_task") + assert ti is not None # sanity check: dag_maker.create_dagrun created the TI + + ti.state = State.RUNNING + ti.last_heartbeat_at = timezone.utcnow() - timedelta(minutes=6) + ti.start_date = timezone.utcnow() - timedelta(minutes=10) + ti.queued_by_job_id = scheduler_job.id + ti.dag_version = dag_v + session.merge(ti) + session.flush() + + executor.running.add(ti.key) + + tis_without_heartbeats = self.job_runner._find_task_instances_without_heartbeats(session=session) + assert len(tis_without_heartbeats) == 1 + ti_from_query = tis_without_heartbeats[0] + ti_key = ti_from_query.key + + # Detach all ORM objects to mirror scheduler behaviour after session closes + session.expunge_all() + + # This should not raise DetachedInstanceError now that eager loads are in place + self.job_runner._purge_task_instances_without_heartbeats(tis_without_heartbeats, session=session) + assert ti_key not in executor.running + + executor.callback_sink.send.assert_called_once() + callback_request = executor.callback_sink.send.call_args.args[0] + assert callback_request.context_from_server is not None + assert len(callback_request.context_from_server.dag_run.consumed_asset_events) == 1 + consumed_event = callback_request.context_from_server.dag_run.consumed_asset_events[0] + assert consumed_event.asset.uri == asset1.uri + assert len(consumed_event.source_aliases) == 1 + assert consumed_event.source_aliases[0].name == "test_alias" + # @pytest.mark.usefixtures("mock_executor") def test_execute_task_instances_backfill_tasks_will_execute(self, dag_maker): """ diff --git a/scripts/ci/prek/check_template_context_variable_in_sync.py b/scripts/ci/prek/check_template_context_variable_in_sync.py index 0b74e4beedbd8..1c55fbd19208e 100755 --- a/scripts/ci/prek/check_template_context_variable_in_sync.py +++ b/scripts/ci/prek/check_template_context_variable_in_sync.py @@ -83,17 +83,25 @@ def extract_keys_from_dict(node: ast.Dict) -> typing.Iterator[str]: 
yield key.value # Extract keys from the main `context` dictionary assignment - context_assignment = next( + context_assignment: ast.AnnAssign = next( stmt for stmt in fn_get_template_context.body if isinstance(stmt, ast.AnnAssign) - and isinstance(stmt.target, ast.Name) - and stmt.target.id == "context" + and isinstance(stmt.target, ast.Attribute) + and isinstance(stmt.target.value, ast.Name) + and stmt.target.value.id == "self" + and stmt.target.attr == "_context" ) - if not isinstance(context_assignment.value, ast.Dict): + if not isinstance(context_assignment.value, ast.BoolOp): + raise TypeError("Expected a BoolOp like 'self._context or {...}'.") + + context_assignment_op = context_assignment.value + _, context_assignment_value = context_assignment_op.values + + if not isinstance(context_assignment_value, ast.Dict): raise ValueError("'context' is not assigned a dictionary literal") - yield from extract_keys_from_dict(context_assignment.value) + yield from extract_keys_from_dict(context_assignment_value) # Handle keys added conditionally in `if from_server` for stmt in fn_get_template_context.body: diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 409982d1a6bf9..7138598027126 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -131,6 +131,9 @@ class RuntimeTaskInstance(TaskInstance): task: BaseOperator bundle_instance: BaseDagBundle + _context: Context | None = None + """The Task Instance context.""" + _ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None """The Task Instance context from the API server, if any.""" @@ -173,7 +176,9 @@ def get_template_context(self) -> Context: validated_params = process_params(self.task.dag, self.task, dag_run_conf, suppress_exception=False) - context: Context = { + # Cache the context object, which ensures that all calls to get_template_context + # are operating on the same context object. + self._context: Context = self._context or { # From the Task Execution interface "dag": self.task.dag, "inlets": self.task.inlets, @@ -213,7 +218,7 @@ def get_template_context(self) -> Context: lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date) ), } - context.update(context_from_server) + self._context.update(context_from_server) if logical_date := coerce_datetime(dag_run.logical_date): if TYPE_CHECKING: @@ -224,7 +229,7 @@ def get_template_context(self) -> Context: ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S") ts_nodash_with_tz = ts.replace("-", "").replace(":", "") # logical_date and data_interval either coexist or be None together - context.update( + self._context.update( { # keys that depend on logical_date "logical_date": logical_date, @@ -251,7 +256,7 @@ def get_template_context(self) -> Context: # existence. Should this be a private attribute on RuntimeTI instead perhaps? 
setattr(self, "_upstream_map_indexes", from_server.upstream_map_indexes) - return context + return self._context def render_templates( self, context: Context | None = None, jinja_env: jinja2.Environment | None = None diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 930d80590ea7b..a2a19266ae5aa 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -65,7 +65,7 @@ ) from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.definitions._internal.types import NOTSET, SET_DURING_EXECUTION, ArgNotSet -from airflow.sdk.definitions.asset import Asset, AssetAlias, Dataset, Model +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, Dataset, Model from airflow.sdk.definitions.param import DagParam from airflow.sdk.exceptions import ErrorType from airflow.sdk.execution_time.comms import ( @@ -2482,6 +2482,32 @@ def on_task_instance_failed(self, previous_state, task_instance, error): def before_stopping(self, component): self.component = component + class CustomOutletEventsListener: + def __init__(self): + self.outlet_events = [] + self.error = None + + def _add_outlet_events(self, context): + outlets = context["outlets"] + for outlet in outlets: + self.outlet_events.append(context["outlet_events"][outlet]) + + @hookimpl + def on_task_instance_running(self, previous_state, task_instance): + context = task_instance.get_template_context() + self._add_outlet_events(context) + + @hookimpl + def on_task_instance_success(self, previous_state, task_instance): + context = task_instance.get_template_context() + self._add_outlet_events(context) + + @hookimpl + def on_task_instance_failed(self, previous_state, task_instance, error): + context = task_instance.get_template_context() + self._add_outlet_events(context) + self.error = error + @pytest.fixture(autouse=True) def clean_listener_manager(self): lm = get_listener_manager() @@ -2601,6 +2627,118 @@ def execute(self, context): assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.FAILED] assert listener.error == error + def test_listener_access_outlet_event_on_running_and_success(self, mocked_parse, mock_supervisor_comms): + """Test listener can access outlet events through invoking get_template_context() while task running and success""" + listener = self.CustomOutletEventsListener() + get_listener_manager().add_listener(listener) + + test_asset = Asset("test-asset") + test_key = AssetUniqueKey(name="test-asset", uri="test-asset") + test_extra = {"name1": "value1", "nested_obj": {"name2": "value2"}} + + class Producer(BaseOperator): + def execute(self, context): + outlet_events = context["outlet_events"] + outlet_events[test_asset].extra = test_extra + + task = Producer( + task_id="test_listener_access_outlet_event_on_running_and_success", outlets=[test_asset] + ) + dag = get_inline_dag(dag_id="test_dag", task=task) + ti = TaskInstance( + id=uuid7(), + task_id=task.task_id, + dag_id=dag.dag_id, + run_id="test_run", + try_number=1, + dag_version_id=uuid7(), + ) + + runtime_ti = RuntimeTaskInstance.model_construct( + **ti.model_dump(exclude_unset=True), task=task, start_date=timezone.utcnow() + ) + + log = mock.MagicMock() + context = runtime_ti.get_template_context() + + with mock.patch( + "airflow.sdk.execution_time.task_runner._validate_task_inlets_and_outlets" + ) as validate_mock: + state, _, _ = run(runtime_ti, context, log) + + 
validate_mock.assert_called_once() + + outlet_event_accessor = listener.outlet_events.pop() + assert outlet_event_accessor.key == test_key + assert outlet_event_accessor.extra == test_extra + + finalize(runtime_ti, state, context, log) + + outlet_event_accessor = listener.outlet_events.pop() + assert outlet_event_accessor.key == test_key + assert outlet_event_accessor.extra == test_extra + + @pytest.mark.parametrize( + "exception", + [ + ValueError("oops"), + SystemExit("oops"), + AirflowException("oops"), + ], + ids=["ValueError", "SystemExit", "AirflowException"], + ) + def test_listener_access_outlet_event_on_failed(self, mocked_parse, mock_supervisor_comms, exception): + """Test listener can access outlet events through invoking get_template_context() while task failed""" + listener = self.CustomOutletEventsListener() + get_listener_manager().add_listener(listener) + + test_asset = Asset("test-asset") + test_key = AssetUniqueKey(name="test-asset", uri="test-asset") + test_extra = {"name1": "value1", "nested_obj": {"name2": "value2"}} + + class Producer(BaseOperator): + def execute(self, context): + outlet_events = context["outlet_events"] + outlet_events[test_asset].extra = test_extra + raise exception + + task = Producer(task_id="test_listener_access_outlet_event_on_failed", outlets=[test_asset]) + dag = get_inline_dag(dag_id="test_dag", task=task) + ti = TaskInstance( + id=uuid7(), + task_id=task.task_id, + dag_id=dag.dag_id, + run_id="test_run", + try_number=1, + dag_version_id=uuid7(), + ) + + runtime_ti = RuntimeTaskInstance.model_construct( + **ti.model_dump(exclude_unset=True), task=task, start_date=timezone.utcnow() + ) + + log = mock.MagicMock() + context = runtime_ti.get_template_context() + + with mock.patch( + "airflow.sdk.execution_time.task_runner._validate_task_inlets_and_outlets" + ) as validate_mock: + state, _, error = run(runtime_ti, context, log) + + validate_mock.assert_called_once() + + outlet_event_accessor = listener.outlet_events.pop() + assert outlet_event_accessor.key == test_key + assert outlet_event_accessor.extra == test_extra + + finalize(runtime_ti, state, context, log, error) + + outlet_event_accessor = listener.outlet_events.pop() + assert outlet_event_accessor.key == test_key + assert outlet_event_accessor.extra == test_extra + + assert listener.error == error + @pytest.mark.usefixtures("mock_supervisor_comms") class TestTaskRunnerCallsCallbacks:
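
Note on the context-caching change in task_runner.py: the new listener tests rely on
get_template_context() returning the same cached object on every call, so outlet-event extras
set by the task inside execute() are visible to listeners that fetch the context themselves.
A minimal sketch of that invariant, not part of the patch; runtime_ti stands for any
RuntimeTaskInstance constructed as in the tests above, and asset for any outlet Asset:

    ctx1 = runtime_ti.get_template_context()
    ctx2 = runtime_ti.get_template_context()
    # The same dict object is returned, not an equal copy, so mutations made
    # through ctx1 (e.g. ctx1["outlet_events"][asset].extra = {...}) are
    # observed through ctx2 without rebuilding the context.
    assert ctx1 is ctx2  # cached on self._context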