From 6188232432d81fcbf3bc5c741fa045865b4a88b4 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Thu, 28 Aug 2025 18:23:30 +0100 Subject: [PATCH 1/2] Revert "Fix rendering of template fields with start from trigger" This reverts commit 4362c25e5cb96e3b6c99497b682d4dfe23a86aba from https://github.com/apache/airflow/pull/53071. The original change violated Airflow's architectural separation by making the Triggerer supervisor process load DAGs and attempt template rendering. This breaks with mapped operators that use lazy sequences requiring SUPERVISOR_COMMS, which only exists in task execution context. The Triggerer should operate purely on trigger database records and pre-computed trigger kwargs, not attempt runtime DAG operations. Fixes triggerer failures with mapped operators using XCom dependencies. --- .../src/airflow/jobs/triggerer_job_runner.py | 61 ++++++----------- .../serialization/serialized_objects.py | 3 +- airflow-core/src/airflow/triggers/base.py | 23 +++---- .../tests/unit/jobs/test_triggerer_job.py | 68 +++---------------- airflow-core/tests/unit/models/test_dagrun.py | 2 +- .../serialization/test_dag_serialization.py | 2 +- task-sdk/docs/api.rst | 2 - task-sdk/src/airflow/sdk/__init__.py | 3 - task-sdk/src/airflow/sdk/__init__.pyi | 2 - task-sdk/src/airflow/sdk/bases/operator.py | 6 +- task-sdk/src/airflow/sdk/bases/trigger.py | 33 --------- .../definitions/_internal/abstractoperator.py | 12 ---- .../airflow/sdk/definitions/mappedoperator.py | 14 +--- 13 files changed, 51 insertions(+), 180 deletions(-) delete mode 100644 task-sdk/src/airflow/sdk/bases/trigger.py diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index d4d49e50f9f51..ba85d00f18ae5 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -43,7 +43,6 @@ from airflow.executors import workloads from airflow.jobs.base_job_runner import BaseJobRunner from airflow.jobs.job import perform_heartbeat -from airflow.models import DagBag from airflow.models.trigger import Trigger from airflow.sdk.api.datamodels._generated import HITLDetailResponse from airflow.sdk.execution_time.comms import ( @@ -606,45 +605,6 @@ def update_triggers(self, requested_trigger_ids: set[int]): trigger set. """ render_log_fname = log_filename_template_renderer() - dag_bag = DagBag(collect_dags=False) - - def expand_start_trigger_args(trigger: Trigger) -> Trigger: - task = dag_bag.get_dag(trigger.task_instance.dag_id).get_task(trigger.task_instance.task_id) - if task.template_fields: - trigger.task_instance.refresh_from_task(task) - context = trigger.task_instance.get_template_context() - task.render_template_fields(context=context) - start_trigger_args = task.expand_start_trigger_args(context=context) - if start_trigger_args: - trigger.kwargs = start_trigger_args.trigger_kwargs - return trigger - - def create_workload(trigger: Trigger) -> workloads.RunTrigger: - if trigger.task_instance: - log_path = render_log_fname(ti=trigger.task_instance) - - trigger = expand_start_trigger_args(trigger) - - ser_ti = workloads.TaskInstance.model_validate(trigger.task_instance, from_attributes=True) - # When producing logs from TIs, include the job id producing the logs to disambiguate it. - self.logger_cache[new_id] = TriggerLoggingFactory( - log_path=f"{log_path}.trigger.{self.job.id}.log", - ti=ser_ti, # type: ignore - ) - - return workloads.RunTrigger( - classpath=trigger.classpath, - id=new_id, - encrypted_kwargs=trigger.encrypted_kwargs, - ti=ser_ti, - timeout_after=trigger.task_instance.trigger_timeout, - ) - return workloads.RunTrigger( - classpath=trigger.classpath, - id=new_id, - encrypted_kwargs=trigger.encrypted_kwargs, - ti=None, - ) known_trigger_ids = ( self.running_triggers.union(x[0] for x in self.events) @@ -682,7 +642,26 @@ def create_workload(trigger: Trigger) -> workloads.RunTrigger: ) continue - workload = create_workload(new_trigger_orm) + workload = workloads.RunTrigger( + classpath=new_trigger_orm.classpath, + id=new_id, + encrypted_kwargs=new_trigger_orm.encrypted_kwargs, + ti=None, + ) + if new_trigger_orm.task_instance: + log_path = render_log_fname(ti=new_trigger_orm.task_instance) + + ser_ti = workloads.TaskInstance.model_validate( + new_trigger_orm.task_instance, from_attributes=True + ) + # When producing logs from TIs, include the job id producing the logs to disambiguate it. + self.logger_cache[new_id] = TriggerLoggingFactory( + log_path=f"{log_path}.trigger.{self.job.id}.log", + ti=ser_ti, # type: ignore + ) + + workload.ti = ser_ti + workload.timeout_after = new_trigger_orm.task_instance.trigger_timeout to_create.append(workload) diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index 6cbf73897038c..aa76448f0cae1 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -52,7 +52,6 @@ from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg from airflow.sdk import Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher, BaseOperator, XComArg from airflow.sdk.bases.operator import OPERATOR_DEFAULTS # TODO: Copy this into the scheduler? -from airflow.sdk.bases.trigger import StartTriggerArgs from airflow.sdk.definitions._internal.expandinput import EXPAND_INPUT_EMPTY from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.definitions.asset import ( @@ -84,7 +83,7 @@ from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep -from airflow.triggers.base import BaseTrigger +from airflow.triggers.base import BaseTrigger, StartTriggerArgs from airflow.utils.code_utils import get_python_source from airflow.utils.context import ( ConnectionAccessor, diff --git a/airflow-core/src/airflow/triggers/base.py b/airflow-core/src/airflow/triggers/base.py index dcb43be9add80..490423da5fda2 100644 --- a/airflow-core/src/airflow/triggers/base.py +++ b/airflow-core/src/airflow/triggers/base.py @@ -19,6 +19,8 @@ import abc import json from collections.abc import AsyncIterator +from dataclasses import dataclass +from datetime import timedelta from typing import Annotated, Any import structlog @@ -31,25 +33,20 @@ ) from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.module_loading import import_string from airflow.utils.state import TaskInstanceState log = structlog.get_logger(logger_name=__name__) -def __getattr__(name: str): - if name == "StartTriggerArgs": - import warnings +@dataclass +class StartTriggerArgs: + """Arguments required for start task execution from triggerer.""" - warnings.warn( - "airflow.triggers.base.StartTriggerArgs is deprecated. " - "Use airflow.sdk.bases.trigger.StartTriggerArgs instead.", - DeprecationWarning, - stacklevel=2, - ) - return import_string(f"airflow.sdk.bases.trigger.{name}") - - raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + trigger_cls: str + next_method: str + trigger_kwargs: dict[str, Any] | None = None + next_kwargs: dict[str, Any] | None = None + timeout: timedelta | None = None class BaseTrigger(abc.ABC, LoggingMixin): diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 2ffa92b1997d4..2c5cc4dd50f24 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -42,7 +42,7 @@ TriggerRunnerSupervisor, messages, ) -from airflow.models import DagBag, DagModel, DagRun, TaskInstance, Trigger +from airflow.models import DagModel, DagRun, TaskInstance, Trigger from airflow.models.connection import Connection from airflow.models.dag import DAG from airflow.models.dag_version import DagVersion @@ -128,15 +128,6 @@ def create_trigger_in_db(session, trigger, operator=None): return dag_model, run, trigger_orm, task_instance -def mock_dag_bag(mock_dag_bag_cls, task_instance: TaskInstance): - mock_dag = MagicMock(spec=DAG) - mock_dag.get_task.return_value = task_instance.task - - mock_dag_bag = MagicMock(spec=DagBag) - mock_dag_bag.get_dag.return_value = mock_dag - mock_dag_bag_cls.return_value = mock_dag_bag - - def test_is_needed(session): """Checks the triggerer-is-needed logic""" # No triggers, no need @@ -215,8 +206,7 @@ def builder(job=None): return builder -@patch("airflow.jobs.triggerer_job_runner.DagBag") -def test_trigger_lifecycle(mock_dag_bag_cls, spy_agency: SpyAgency, session, testing_dag_bundle): +def test_trigger_lifecycle(spy_agency: SpyAgency, session, testing_dag_bundle): """ Checks that the triggerer will correctly see a new Trigger in the database and send it to the trigger runner, and then delete it when it vanishes. @@ -225,8 +215,6 @@ def test_trigger_lifecycle(mock_dag_bag_cls, spy_agency: SpyAgency, session, tes # (we want to avoid it firing and deleting itself) trigger = TimeDeltaTrigger(datetime.timedelta(days=7)) dag_model, run, trigger_orm, task_instance = create_trigger_in_db(session, trigger) - mock_dag_bag(mock_dag_bag_cls, task_instance) - # Make a TriggererJobRunner and have it retrieve DB tasks trigger_runner_supervisor = TriggerRunnerSupervisor.start(job=Job(id=12345), capacity=10) @@ -409,10 +397,7 @@ async def test_trigger_kwargs_serialization_cleanup(self, session): @pytest.mark.asyncio -@patch("airflow.jobs.triggerer_job_runner.DagBag") -async def test_trigger_create_race_condition_38599( - mock_dag_bag_cls, session, supervisor_builder, testing_dag_bundle -): +async def test_trigger_create_race_condition_38599(session, supervisor_builder, testing_dag_bundle): """ This verifies the resolution of race condition documented in github issue #38599. More details in the issue description. @@ -441,14 +426,10 @@ async def test_trigger_create_race_condition_38599( dm = DagModel(dag_id="test-dag", bundle_name=bundle_name) session.add(dm) SerializedDagModel.write_dag(dag, bundle_name=bundle_name) - dag_run = DagRun( - dag.dag_id, run_id="abc", run_type="manual", start_date=timezone.utcnow(), run_after=timezone.utcnow() - ) + dag_run = DagRun(dag.dag_id, run_id="abc", run_type="none", run_after=timezone.utcnow()) dag_version = DagVersion.get_latest_version(dag.dag_id) - task = PythonOperator(task_id="dummy-task", python_callable=print) - task.dag = dag ti = TaskInstance( - task, + PythonOperator(task_id="dummy-task", python_callable=print), run_id=dag_run.run_id, state=TaskInstanceState.DEFERRED, dag_version_id=dag_version.id, @@ -465,8 +446,6 @@ async def test_trigger_create_race_condition_38599( session.commit() - mock_dag_bag(mock_dag_bag_cls, ti) - supervisor1 = supervisor_builder(job1) supervisor2 = supervisor_builder(job2) @@ -600,8 +579,7 @@ async def test_trigger_failing(): info["task"].cancel() -@patch("airflow.jobs.triggerer_job_runner.DagBag") -def test_failed_trigger(mock_dag_bag_cls, session, dag_maker, supervisor_builder): +def test_failed_trigger(session, dag_maker, supervisor_builder): """ Checks that the triggerer will correctly fail task instances that depend on triggers that can't even be loaded. @@ -624,8 +602,6 @@ def test_failed_trigger(mock_dag_bag_cls, session, dag_maker, supervisor_builder task_instance.trigger_id = trigger_orm.id session.commit() - mock_dag_bag(mock_dag_bag_cls, task_instance) - supervisor: TriggerRunnerSupervisor = supervisor_builder() supervisor.load_triggers() @@ -771,8 +747,7 @@ def handle_events(self): @pytest.mark.asyncio @pytest.mark.execution_timeout(20) -@patch("airflow.jobs.triggerer_job_runner.DagBag") -async def test_trigger_can_call_variables_connections_and_xcoms_methods(mock_dag_bag_cls, session, dag_maker): +async def test_trigger_can_call_variables_connections_and_xcoms_methods(session, dag_maker): """Checks that the trigger will successfully call Variables, Connections and XComs methods.""" # Create the test DAG and task with dag_maker(dag_id="trigger_accessing_variable_connection_and_xcom", session=session): @@ -834,8 +809,6 @@ async def test_trigger_can_call_variables_connections_and_xcoms_methods(mock_dag session.add(job) session.commit() - mock_dag_bag(mock_dag_bag_cls, task_instance) - supervisor = DummyTriggerRunnerSupervisor.start(job=job, capacity=1, logger=None) supervisor.run() @@ -906,10 +879,7 @@ async def run(self, **args) -> AsyncIterator[TriggerEvent]: @pytest.mark.asyncio @pytest.mark.execution_timeout(10) -@patch("airflow.jobs.triggerer_job_runner.DagBag") -async def test_trigger_can_fetch_trigger_dag_run_count_and_state_in_deferrable( - mock_dag_bag_cls, session, dag_maker -): +async def test_trigger_can_fetch_trigger_dag_run_count_and_state_in_deferrable(session, dag_maker): """Checks that the trigger will successfully fetch the count of trigger DAG runs.""" # Create the test DAG and task with dag_maker(dag_id="trigger_can_fetch_trigger_dag_run_count_and_state_in_deferrable", session=session): @@ -940,8 +910,6 @@ async def test_trigger_can_fetch_trigger_dag_run_count_and_state_in_deferrable( session.add(job) session.commit() - mock_dag_bag(mock_dag_bag_cls, task_instance) - supervisor = DummyTriggerRunnerSupervisor.start(job=job, capacity=1, logger=None) supervisor.run() @@ -1002,8 +970,7 @@ async def run(self, **args) -> AsyncIterator[TriggerEvent]: @pytest.mark.asyncio @pytest.mark.execution_timeout(10) -@patch("airflow.jobs.triggerer_job_runner.DagBag") -async def test_trigger_can_fetch_dag_run_count_ti_count_in_deferrable(mock_dag_bag_cls, session, dag_maker): +async def test_trigger_can_fetch_dag_run_count_ti_count_in_deferrable(session, dag_maker): """Checks that the trigger will successfully fetch the count of DAG runs, Task count and task states.""" # Create the test DAG and task with dag_maker(dag_id="parent_dag", session=session): @@ -1044,8 +1011,6 @@ async def test_trigger_can_fetch_dag_run_count_ti_count_in_deferrable(mock_dag_b session.add(job) session.commit() - mock_dag_bag(mock_dag_bag_cls, task_instance) - supervisor = DummyTriggerRunnerSupervisor.start(job=job, capacity=1, logger=None) supervisor.run() @@ -1058,10 +1023,7 @@ async def test_trigger_can_fetch_dag_run_count_ti_count_in_deferrable(mock_dag_b } -@patch("airflow.jobs.triggerer_job_runner.DagBag") -def test_update_triggers_prevents_duplicate_creation_queue_entries( - mock_dag_bag_cls, session, supervisor_builder -): +def test_update_triggers_prevents_duplicate_creation_queue_entries(session, supervisor_builder): """ Test that update_triggers prevents adding triggers to the creation queue if they are already queued for creation. @@ -1069,8 +1031,6 @@ def test_update_triggers_prevents_duplicate_creation_queue_entries( trigger = TimeDeltaTrigger(datetime.timedelta(days=7)) dag_model, run, trigger_orm, task_instance = create_trigger_in_db(session, trigger) - mock_dag_bag(mock_dag_bag_cls, task_instance) - supervisor = supervisor_builder() # First call to update_triggers should add the trigger to creating_triggers @@ -1092,9 +1052,8 @@ def test_update_triggers_prevents_duplicate_creation_queue_entries( assert not any(trigger_id == trigger_orm.id for trigger_id, _ in supervisor.failed_triggers) -@patch("airflow.jobs.triggerer_job_runner.DagBag") def test_update_triggers_prevents_duplicate_creation_queue_entries_with_multiple_triggers( - mock_dag_bag_cls, session, supervisor_builder, dag_maker + session, supervisor_builder, dag_maker ): """ Test that update_triggers prevents adding multiple triggers to the creation queue @@ -1105,8 +1064,6 @@ def test_update_triggers_prevents_duplicate_creation_queue_entries_with_multiple dag_model1, run1, trigger_orm1, task_instance1 = create_trigger_in_db(session, trigger1) - mock_dag_bag(mock_dag_bag_cls, task_instance1) - with dag_maker("test_dag_2"): EmptyOperator(task_id="test_ti_2") @@ -1115,9 +1072,6 @@ def test_update_triggers_prevents_duplicate_creation_queue_entries_with_multiple ti2 = run2.task_instances[0] session.add(trigger_orm2) session.flush() - - mock_dag_bag(mock_dag_bag_cls, ti2) - ti2.trigger_id = trigger_orm2.id session.merge(ti2) session.flush() diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index 9dfe096b19385..9c85dc041e618 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -44,11 +44,11 @@ from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator, ShortCircuitOperator from airflow.sdk import BaseOperator, setup, task, task_group, teardown -from airflow.sdk.bases.trigger import StartTriggerArgs from airflow.sdk.definitions.deadline import AsyncCallback, DeadlineAlert, DeadlineReference from airflow.serialization.serialized_objects import SerializedDAG from airflow.stats import Stats from airflow.task.trigger_rule import TriggerRule +from airflow.triggers.base import StartTriggerArgs from airflow.utils.span_status import SpanStatus from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.thread_safe_dict import ThreadSafeDict diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index ed3fc7306a952..7384d09268606 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -65,7 +65,6 @@ from airflow.sdk import AssetAlias, BaseHook, teardown from airflow.sdk.bases.decorator import DecoratedOperator from airflow.sdk.bases.operator import BaseOperator -from airflow.sdk.bases.trigger import StartTriggerArgs from airflow.sdk.definitions._internal.expandinput import EXPAND_INPUT_EMPTY from airflow.sdk.definitions.asset import Asset, AssetUniqueKey from airflow.sdk.definitions.operator_resources import Resources @@ -83,6 +82,7 @@ from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep from airflow.timetables.simple import NullTimetable, OnceTimetable +from airflow.triggers.base import StartTriggerArgs from airflow.utils.module_loading import qualname from tests_common.test_utils.config import conf_vars diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst index fc375ab48c5dd..40ea8e17287a4 100644 --- a/task-sdk/docs/api.rst +++ b/task-sdk/docs/api.rst @@ -72,8 +72,6 @@ Bases .. autoapiclass:: airflow.sdk.BaseHook -.. autoapiclass:: airflow.sdk.StartTriggerArgs - Connections & Variables ----------------------- .. autoapiclass:: airflow.sdk.Connection diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py index 7c4cfacc22744..cc29ef4057bde 100644 --- a/task-sdk/src/airflow/sdk/__init__.py +++ b/task-sdk/src/airflow/sdk/__init__.py @@ -39,7 +39,6 @@ "ObjectStoragePath", "Param", "PokeReturnValue", - "StartTriggerArgs", "TaskGroup", "TriggerRule", "Variable", @@ -68,7 +67,6 @@ from airflow.sdk.bases.operator import BaseOperator, chain, chain_linear, cross_downstream from airflow.sdk.bases.operatorlink import BaseOperatorLink from airflow.sdk.bases.sensor import BaseSensorOperator, PokeReturnValue - from airflow.sdk.bases.trigger import StartTriggerArgs from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher from airflow.sdk.definitions.asset.decorators import asset from airflow.sdk.definitions.asset.metadata import Metadata @@ -106,7 +104,6 @@ "Param": ".definitions.param", "PokeReturnValue": ".bases.sensor", "SecretCache": ".execution_time.cache", - "StartTriggerArgs": ".bases.trigger", "TaskGroup": ".definitions.taskgroup", "TriggerRule": ".api.datamodels._generated", "Variable": ".definitions.variable", diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi index 5fe59a9f19400..c1112d26c04ea 100644 --- a/task-sdk/src/airflow/sdk/__init__.pyi +++ b/task-sdk/src/airflow/sdk/__init__.pyi @@ -29,7 +29,6 @@ from airflow.sdk.bases.sensor import ( BaseSensorOperator as BaseSensorOperator, PokeReturnValue as PokeReturnValue, ) -from airflow.sdk.bases.trigger import StartTriggerArgs as StartTriggerArgs from airflow.sdk.definitions.asset import ( Asset as Asset, AssetAlias as AssetAlias, @@ -79,7 +78,6 @@ __all__ = [ "Param", "PokeReturnValue", "SecretCache", - "StartTriggerArgs", "TaskGroup", "TriggerRule", "Variable", diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py index 762e2118500a0..3ece4b9937288 100644 --- a/task-sdk/src/airflow/sdk/bases/operator.py +++ b/task-sdk/src/airflow/sdk/bases/operator.py @@ -89,7 +89,6 @@ def db_safe_priority(priority_weight: int) -> int: import jinja2 from airflow.sdk.bases.operatorlink import BaseOperatorLink - from airflow.sdk.bases.trigger import StartTriggerArgs from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.operator_resources import Resources @@ -97,7 +96,7 @@ def db_safe_priority(priority_weight: int) -> int: from airflow.sdk.definitions.xcom_arg import XComArg from airflow.serialization.enums import DagAttributeTypes from airflow.task.priority_strategy import PriorityWeightStrategy - from airflow.triggers.base import BaseTrigger + from airflow.triggers.base import BaseTrigger, StartTriggerArgs from airflow.typing_compat import Self TaskPreExecuteHook = Callable[[Context], None] @@ -946,6 +945,9 @@ def say_hello_world(**context): # Set to True for an operator instantiated by a mapped operator. __from_mapped: bool = False + start_trigger_args: StartTriggerArgs | None = None + start_from_trigger: bool = False + # base list which includes all the attrs that don't need deep copy. _base_operator_shallow_copy_attrs: Final[tuple[str, ...]] = ( "user_defined_macros", diff --git a/task-sdk/src/airflow/sdk/bases/trigger.py b/task-sdk/src/airflow/sdk/bases/trigger.py deleted file mode 100644 index 5429b4c6adec8..0000000000000 --- a/task-sdk/src/airflow/sdk/bases/trigger.py +++ /dev/null @@ -1,33 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from __future__ import annotations - -from dataclasses import dataclass -from datetime import timedelta -from typing import Any - - -@dataclass -class StartTriggerArgs: - """Arguments required for start task execution from triggerer.""" - - trigger_cls: str - next_method: str - trigger_kwargs: dict[str, Any] | None = None - next_kwargs: dict[str, Any] | None = None - timeout: timedelta | None = None diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py index f7db152a5cd6c..7629c06f7954c 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py @@ -43,7 +43,6 @@ from airflow.sdk.bases.operator import BaseOperator from airflow.sdk.bases.operatorlink import BaseOperatorLink - from airflow.sdk.bases.trigger import StartTriggerArgs from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.taskgroup import MappedTaskGroup @@ -120,9 +119,6 @@ class AbstractOperator(Templater, DAGNode): is_setup: bool = False is_teardown: bool = False - start_trigger_args: StartTriggerArgs | None = None - start_from_trigger: bool = False - HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset( ( "log", @@ -319,14 +315,6 @@ def _do_render_template_fields( else: setattr(parent, attr_name, rendered_content) - if ( - self.start_from_trigger - and self.start_trigger_args - and self.start_trigger_args.trigger_kwargs - ): - if attr_name in self.start_trigger_args.trigger_kwargs: - self.start_trigger_args.trigger_kwargs[attr_name] = rendered_content - def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]: """ Return mapped nodes that are direct dependencies of the current task. diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py index a2e7e7bf11c74..febc922ba265c 100644 --- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -64,19 +64,11 @@ OperatorExpandArgument, OperatorExpandKwargsArgument, ) - from airflow.sdk import ( - DAG, - BaseOperator, - BaseOperatorLink, - Context, - StartTriggerArgs, - TaskGroup, - TriggerRule, - XComArg, - ) + from airflow.sdk import DAG, BaseOperator, BaseOperatorLink, Context, TaskGroup, TriggerRule, XComArg from airflow.sdk.definitions._internal.expandinput import ExpandInput from airflow.sdk.definitions.operator_resources import Resources from airflow.sdk.definitions.param import ParamsDict + from airflow.triggers.base import StartTriggerArgs ValidationSource = Literal["expand"] | Literal["partial"] @@ -825,7 +817,7 @@ def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | N This method is for allowing mapped operator to start execution from triggerer. """ - from airflow.sdk.bases.trigger import StartTriggerArgs + from airflow.triggers.base import StartTriggerArgs if not self.start_trigger_args: return None From abf5600c28e74e436b44876e494a2d72f991ca2e Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Thu, 28 Aug 2025 21:02:01 +0100 Subject: [PATCH 2/2] fixup! Revert "Fix rendering of template fields with start from trigger" --- .../src/airflow/models/mappedoperator.py | 3 ++- .../providers/standard/sensors/date_time.py | 27 +++++++++---------- .../providers/standard/sensors/filesystem.py | 27 +++++++++---------- .../providers/standard/sensors/time.py | 27 +++++++++---------- 4 files changed, 38 insertions(+), 46 deletions(-) diff --git a/airflow-core/src/airflow/models/mappedoperator.py b/airflow-core/src/airflow/models/mappedoperator.py index bda98c8b5a5cb..01f0c04205b70 100644 --- a/airflow-core/src/airflow/models/mappedoperator.py +++ b/airflow-core/src/airflow/models/mappedoperator.py @@ -59,11 +59,12 @@ from airflow.models import TaskInstance from airflow.models.dag import DAG as SchedulerDAG from airflow.models.expandinput import SchedulerExpandInput - from airflow.sdk import BaseOperatorLink, Context, StartTriggerArgs + from airflow.sdk import BaseOperatorLink, Context from airflow.sdk.definitions.operator_resources import Resources from airflow.sdk.definitions.param import ParamsDict from airflow.task.trigger_rule import TriggerRule from airflow.ti_deps.deps.base_ti_dep import BaseTIDep + from airflow.triggers.base import StartTriggerArgs Operator: TypeAlias = "SerializedBaseOperator | MappedOperator" diff --git a/providers/standard/src/airflow/providers/standard/sensors/date_time.py b/providers/standard/src/airflow/providers/standard/sensors/date_time.py index fc6ff69c88ace..eae51eaacff2e 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/date_time.py +++ b/providers/standard/src/airflow/providers/standard/sensors/date_time.py @@ -31,21 +31,18 @@ from airflow.utils import timezone # type: ignore[attr-defined,no-redef] try: - from airflow.sdk import StartTriggerArgs -except ImportError: # TODO: Remove this when min airflow version is 3.1.0 for standard provider - try: - from airflow.triggers.base import StartTriggerArgs # type: ignore[no-redef] - except ImportError: # TODO: Remove this when min airflow version is 2.10.0 for standard provider - - @dataclass - class StartTriggerArgs: # type: ignore[no-redef] - """Arguments required for start task execution from triggerer.""" - - trigger_cls: str - next_method: str - trigger_kwargs: dict[str, Any] | None = None - next_kwargs: dict[str, Any] | None = None - timeout: datetime.timedelta | None = None + from airflow.triggers.base import StartTriggerArgs # type: ignore[no-redef] +except ImportError: # TODO: Remove this when min airflow version is 2.10.0 for standard provider + + @dataclass + class StartTriggerArgs: # type: ignore[no-redef] + """Arguments required for start task execution from triggerer.""" + + trigger_cls: str + next_method: str + trigger_kwargs: dict[str, Any] | None = None + next_kwargs: dict[str, Any] | None = None + timeout: datetime.timedelta | None = None if TYPE_CHECKING: diff --git a/providers/standard/src/airflow/providers/standard/sensors/filesystem.py b/providers/standard/src/airflow/providers/standard/sensors/filesystem.py index 86e9e3133b029..d23906c3404a2 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/filesystem.py +++ b/providers/standard/src/airflow/providers/standard/sensors/filesystem.py @@ -32,21 +32,18 @@ from airflow.providers.standard.version_compat import BaseSensorOperator try: - from airflow.sdk import StartTriggerArgs -except ImportError: # TODO: Remove this when min airflow version is 3.1.0 for standard provider - try: - from airflow.triggers.base import StartTriggerArgs # type: ignore[no-redef] - except ImportError: # TODO: Remove this when min airflow version is 2.10.0 for standard provider - - @dataclass - class StartTriggerArgs: # type: ignore[no-redef] - """Arguments required for start task execution from triggerer.""" - - trigger_cls: str - next_method: str - trigger_kwargs: dict[str, Any] | None = None - next_kwargs: dict[str, Any] | None = None - timeout: datetime.timedelta | None = None + from airflow.triggers.base import StartTriggerArgs # type: ignore[no-redef] +except ImportError: # TODO: Remove this when min airflow version is 2.10.0 for standard provider + + @dataclass + class StartTriggerArgs: # type: ignore[no-redef] + """Arguments required for start task execution from triggerer.""" + + trigger_cls: str + next_method: str + trigger_kwargs: dict[str, Any] | None = None + next_kwargs: dict[str, Any] | None = None + timeout: datetime.timedelta | None = None if TYPE_CHECKING: diff --git a/providers/standard/src/airflow/providers/standard/sensors/time.py b/providers/standard/src/airflow/providers/standard/sensors/time.py index 988b4b85843f4..129ae22fbd882 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/time.py +++ b/providers/standard/src/airflow/providers/standard/sensors/time.py @@ -28,21 +28,18 @@ from airflow.providers.standard.version_compat import BaseSensorOperator try: - from airflow.sdk import StartTriggerArgs -except ImportError: # TODO: Remove this when min airflow version is 3.1.0 for standard provider - try: - from airflow.triggers.base import StartTriggerArgs # type: ignore[no-redef] - except ImportError: # TODO: Remove this when min airflow version is 2.10.0 for standard provider - - @dataclass - class StartTriggerArgs: # type: ignore[no-redef] - """Arguments required for start task execution from triggerer.""" - - trigger_cls: str - next_method: str - trigger_kwargs: dict[str, Any] | None = None - next_kwargs: dict[str, Any] | None = None - timeout: datetime.timedelta | None = None + from airflow.triggers.base import StartTriggerArgs # type: ignore[no-redef] +except ImportError: # TODO: Remove this when min airflow version is 2.10.0 for standard provider + + @dataclass + class StartTriggerArgs: # type: ignore[no-redef] + """Arguments required for start task execution from triggerer.""" + + trigger_cls: str + next_method: str + trigger_kwargs: dict[str, Any] | None = None + next_kwargs: dict[str, Any] | None = None + timeout: datetime.timedelta | None = None try: