diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 7adaf8a44735f..2ae85a9c435b2 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -509,6 +509,8 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
             # task's expand() contribute to the op_kwargs operator argument, not
             # the operator arguments themselves, and should expand against it.
             expand_input_attr="op_kwargs_expand_input",
+            start_trigger=self.operator_class.start_trigger,
+            next_method=self.operator_class.next_method,
         )
         return XComArg(operator=operator)
 
diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index fa59d6cfc948d..3320b10109e7b 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -122,6 +122,8 @@ class AbstractOperator(Templater, DAGNode):
             "node_id",  # Duplicates task_id
             "task_group",  # Doesn't have a useful repr, no point showing in UI
             "inherits_from_empty_operator",  # impl detail
+            "start_trigger",
+            "next_method",
             # For compatibility with TG, for operators these are just the current task, no point showing
             "roots",
             "leaves",
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 0fe9bdf38080b..ded7d2861efa2 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -818,6 +818,9 @@ def say_hello_world(**context):
     # Set to True for an operator instantiated by a mapped operator.
     __from_mapped = False
 
+    start_trigger: BaseTrigger | None = None
+    next_method: str | None = None
+
     def __init__(
         self,
         task_id: str,
@@ -1675,6 +1678,8 @@ def get_serialized_fields(cls):
                     "is_teardown",
                     "on_failure_fail_dagrun",
                     "map_index_template",
+                    "start_trigger",
+                    "next_method",
                 }
             )
         DagContext.pop_context_managed_dag()
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 777280e9aa7f1..774f8d0983192 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -51,7 +51,7 @@
 from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.callbacks.callback_requests import DagCallbackRequest
 from airflow.configuration import conf as airflow_conf
-from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskNotFound
+from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskDeferred, TaskNotFound
 from airflow.listeners.listener import get_listener_manager
 from airflow.models import Log
 from airflow.models.abstractoperator import NotMapped
@@ -905,9 +905,11 @@ def recalculate(self) -> _UnfinishedStates:
             self.run_id,
             self.start_date,
             self.end_date,
-            (self.end_date - self.start_date).total_seconds()
-            if self.start_date and self.end_date
-            else None,
+            (
+                (self.end_date - self.start_date).total_seconds()
+                if self.start_date and self.end_date
+                else None
+            ),
             self._state,
             self.external_trigger,
             self.run_type,
@@ -1537,6 +1539,18 @@ def schedule_tis(
                 and not ti.task.outlets
             ):
                 dummy_ti_ids.append((ti.task_id, ti.map_index))
+            elif (
+                ti.task.start_trigger is not None
+                and ti.task.next_method is not None
+                and not ti.task.on_execute_callback
+                and not ti.task.on_success_callback
+                and not ti.task.outlets
+            ):
+                ti._try_number += 1
+                ti.defer_task(
+                    defer=TaskDeferred(trigger=ti.task.start_trigger, method_name=ti.task.next_method),
+                    session=session,
+                )
             else:
                 schedulable_ti_ids.append((ti.task_id, ti.map_index))
 
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index edb19aada6065..26a231456ab98 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -81,6 +81,7 @@
     from airflow.models.param import ParamsDict
     from airflow.models.xcom_arg import XComArg
     from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
+    from airflow.triggers.base import BaseTrigger
     from airflow.utils.context import Context
     from airflow.utils.operator_resources import Resources
     from airflow.utils.task_group import TaskGroup
@@ -236,6 +237,8 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
             # For classic operators, this points to expand_input because kwargs
             # to BaseOperator.expand() contribute to operator arguments.
             expand_input_attr="expand_input",
+            start_trigger=self.operator_class.start_trigger,
+            next_method=self.operator_class.next_method,
         )
         return op
 
@@ -278,6 +281,8 @@ class MappedOperator(AbstractOperator):
     _task_module: str
     _task_type: str
     _operator_name: str
+    start_trigger: BaseTrigger | None
+    next_method: str | None
 
     dag: DAG | None
     task_group: TaskGroup | None
@@ -306,6 +311,8 @@ class MappedOperator(AbstractOperator):
         (
             "parse_time_mapped_ti_count",
             "operator_class",
+            "start_trigger",
+            "next_method",
         )
     )
 
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 0fbc07a11d79f..5d3088fa4c597 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -70,6 +70,7 @@
 from airflow.utils.code_utils import get_python_source
 from airflow.utils.context import Context, DatasetEventAccessor, DatasetEventAccessors
 from airflow.utils.docs import get_docs_url
+from airflow.utils.helpers import exactly_one
 from airflow.utils.module_loading import import_string, qualname
 from airflow.utils.operator_resources import Resources
 from airflow.utils.task_group import MappedTaskGroup, TaskGroup
@@ -1016,6 +1017,7 @@ def serialize_operator(cls, op: BaseOperator | MappedOperator) -> dict[str, Any]
     def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool) -> dict[str, Any]:
         """Serialize operator into a JSON object."""
         serialize_op = cls.serialize_to_json(op, cls._decorated_fields)
+
         serialize_op["_task_type"] = getattr(op, "_task_type", type(op).__name__)
         serialize_op["_task_module"] = getattr(op, "_task_module", type(op).__module__)
         if op.operator_name != serialize_op["_task_type"]:
@@ -1024,6 +1026,12 @@ def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool)
         # Used to determine if an Operator is inherited from EmptyOperator
         serialize_op["_is_empty"] = op.inherits_from_empty_operator
 
+        if exactly_one(op.start_trigger is not None, op.next_method is not None):
+            raise AirflowException("start_trigger and next_method should both be set.")
+
+        serialize_op["start_trigger"] = op.start_trigger.serialize() if op.start_trigger else None
+        serialize_op["next_method"] = op.next_method
+
         if op.operator_extra_links:
             serialize_op["_operator_extra_links"] = cls._serialize_operator_extra_links(
                 op.operator_extra_links.__get__(op)
@@ -1204,6 +1212,17 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None:
         # Used to determine if an Operator is inherited from EmptyOperator
         setattr(op, "_is_empty", bool(encoded_op.get("_is_empty", False)))
 
+        # Deserialize start_trigger
+        serialized_start_trigger = encoded_op.get("start_trigger")
+        if serialized_start_trigger:
+            trigger_cls_name, trigger_kwargs = serialized_start_trigger
+            trigger_cls = import_string(trigger_cls_name)
+            start_trigger = trigger_cls(**trigger_kwargs)
+ setattr(op, "start_trigger", start_trigger) + else: + setattr(op, "start_trigger", None) + setattr(op, "next_method", encoded_op.get("next_method", None)) + @staticmethod def set_task_dag_references(task: Operator, dag: DAG) -> None: """Handle DAG references on an operator. @@ -1241,6 +1260,7 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator: operator_name = encoded_op["_operator_name"] except KeyError: operator_name = encoded_op["_task_type"] + op = MappedOperator( operator_class=op_data, expand_input=EXPAND_INPUT_EMPTY, @@ -1264,6 +1284,8 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator: end_date=None, disallow_kwargs_override=encoded_op["_disallow_kwargs_override"], expand_input_attr=encoded_op["_expand_input_attr"], + start_trigger=None, + next_method=None, ) else: op = SerializedBaseOperator(task_id=encoded_op["task_id"]) @@ -1686,9 +1708,11 @@ def set_ref(task: Operator) -> Operator: return task group.children = { - label: set_ref(task_dict[val]) - if _type == DAT.OP - else cls.deserialize_task_group(val, group, task_dict, dag=dag) + label: ( + set_ref(task_dict[val]) + if _type == DAT.OP + else cls.deserialize_task_group(val, group, task_dict, dag=dag) + ) for label, (_type, val) in encoded_group["children"].items() } group.upstream_group_ids.update(cls.deserialize(encoded_group["upstream_group_ids"])) diff --git a/docs/apache-airflow/authoring-and-scheduling/deferring.rst b/docs/apache-airflow/authoring-and-scheduling/deferring.rst index b9e304673109f..6777f5a9d6caa 100644 --- a/docs/apache-airflow/authoring-and-scheduling/deferring.rst +++ b/docs/apache-airflow/authoring-and-scheduling/deferring.rst @@ -141,6 +141,60 @@ The ``self.defer`` call raises the ``TaskDeferred`` exception, so it can work an ``execution_timeout`` on operators is determined from the *total runtime*, not individual executions between deferrals. This means that if ``execution_timeout`` is set, an operator can fail while it's deferred or while it's running after a deferral, even if it's only been resumed for a few seconds. +Triggering Deferral from Start +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you want to defer your task directly to the triggerer without going into the worker, you can add the class level attributes ``start_trigger`` and ``_next_method`` to your deferrable operator. + +* ``start_trigger``: An instance of a trigger you want to defer to. It will be serialized into the database. +* ``next_method``: The method name on your operator that you want Airflow to call when it resumes. + + +This is particularly useful when deferring is the only thing the ``execute`` method does. Here's a basic refinement of the previous example. + +.. code-block:: python + + from datetime import timedelta + from typing import Any + + from airflow.sensors.base import BaseSensorOperator + from airflow.triggers.temporal import TimeDeltaTrigger + from airflow.utils.context import Context + + + class WaitOneHourSensor(BaseSensorOperator): + start_trigger = TimeDeltaTrigger(timedelta(hours=1)) + next_method = "execute_complete" + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: + # We have no more work to do here. Mark as complete. + return + +``start_trigger`` and ``next_method`` can also be set at the instance level for more flexible configuration. + +.. warning:: + Dynamic task mapping is not supported when ``start_trigger`` and ``next_method`` are assigned in instance level. + +.. 
code-block:: python + + from datetime import timedelta + from typing import Any + + from airflow.sensors.base import BaseSensorOperator + from airflow.triggers.temporal import TimeDeltaTrigger + from airflow.utils.context import Context + + + class WaitOneHourSensor(BaseSensorOperator): + def __init__(self, *args: list[Any], **kwargs: dict[str, Any]) -> None: + super().__init__(*args, **kwargs) + self.start_trigger = TimeDeltaTrigger(timedelta(hours=1)) + self.next_method = "execute_complete" + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: + # We have no more work to do here. Mark as complete. + return + Writing Triggers ~~~~~~~~~~~~~~~~ diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 929cf9ae87e3b..8a10c28f1c4a5 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -40,6 +40,7 @@ from airflow.operators.python import ShortCircuitOperator from airflow.serialization.serialized_objects import SerializedDAG from airflow.stats import Stats +from airflow.triggers.testing import SuccessTrigger from airflow.utils import timezone from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.trigger_rule import TriggerRule @@ -1986,6 +1987,33 @@ def test_schedule_tis_map_index(dag_maker, session): assert ti2.state == TaskInstanceState.SUCCESS +def test_schedule_tis_start_trigger(dag_maker, session): + """ + Test that an operator with _start_trigger and _next_method set can be directly + deferred during scheduling. + """ + trigger = SuccessTrigger() + + class TestOperator(BaseOperator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.start_trigger = trigger + self.next_method = "execute_complete" + + def execute_complete(self): + pass + + with dag_maker(session=session): + task = TestOperator(task_id="test_task") + + dr: DagRun = dag_maker.create_dagrun() + + ti = TI(task=task, run_id=dr.run_id, state=None) + assert ti.state is None + dr.schedule_tis((ti,), session=session) + assert ti.state == TaskInstanceState.DEFERRED + + def test_mapped_expand_kwargs(dag_maker): with dag_maker(): diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index ed63a565f9a5f..623b2e673102a 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -73,6 +73,7 @@ from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.timetables.simple import NullTimetable, OnceTimetable +from airflow.triggers.testing import SuccessTrigger from airflow.utils import timezone from airflow.utils.operator_resources import Resources from airflow.utils.task_group import TaskGroup @@ -195,6 +196,8 @@ def detect_task_dependencies(task: Operator) -> DagDependency | None: # type: i "doc_md": "### Task Tutorial Documentation", "_log_config_logger_name": "airflow.task.operators", "weight_rule": "downstream", + "next_method": None, + "start_trigger": None, }, }, { @@ -222,6 +225,8 @@ def detect_task_dependencies(task: Operator) -> DagDependency | None: # type: i "on_failure_fail_dagrun": False, "_log_config_logger_name": "airflow.task.operators", "weight_rule": "downstream", + "next_method": None, + "start_trigger": None, }, }, ], @@ -2133,6 +2138,64 @@ def execute(self, context: Context): ): SerializedDAG.to_dict(dag) + @pytest.mark.db_test + def 
test_start_trigger_and_next_method_in_serialized_dag(self): + """ + Test that when we provide start_trigger and next_method, the DAG can be correctly serialized. + """ + trigger = SuccessTrigger() + + class TestOperator(BaseOperator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.start_trigger = trigger + self.next_method = "execute_complete" + + def execute_complete(self): + pass + + class Test2Operator(BaseOperator): + def __init__(self, *args, **kwargs): + self.start_trigger = trigger + self.next_method = "execute_complete" + super().__init__(*args, **kwargs) + + def execute_complete(self): + pass + + dag = DAG(dag_id="test_dag", start_date=datetime(2023, 11, 9)) + + with dag: + TestOperator(task_id="test_task_1") + Test2Operator(task_id="test_task_2") + + serialized_obj = SerializedDAG.to_dict(dag) + + for task in serialized_obj["dag"]["tasks"]: + assert task["__var"]["start_trigger"] == trigger.serialize() + assert task["__var"]["next_method"] == "execute_complete" + + @pytest.mark.db_test + def test_start_trigger_in_serialized_dag_but_no_next_method(self): + """ + Test that when we provide start_trigger without next_method, an AriflowException should be raised. + """ + + trigger = SuccessTrigger() + + class TestOperator(BaseOperator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.start_trigger = trigger + + dag = DAG(dag_id="test_dag", start_date=datetime(2023, 11, 9)) + + with dag: + TestOperator(task_id="test_task") + + with pytest.raises(AirflowException, match="start_trigger and next_method should both be set."): + SerializedDAG.to_dict(dag) + def test_kubernetes_optional(): """Serialisation / deserialisation continues to work without kubernetes installed""" @@ -2182,6 +2245,8 @@ def test_operator_expand_serde(): "_is_mapped": True, "_task_module": "airflow.operators.bash", "_task_type": "BashOperator", + "start_trigger": None, + "next_method": None, "downstream_task_ids": [], "expand_input": { "type": "dict-of-lists", @@ -2213,6 +2278,8 @@ def test_operator_expand_serde(): assert op.operator_class == { "_task_type": "BashOperator", + "start_trigger": None, + "next_method": None, "downstream_task_ids": [], "task_id": "a", "template_ext": [".sh", ".bash"], @@ -2257,6 +2324,8 @@ def test_operator_expand_xcomarg_serde(): "ui_fgcolor": "#000", "_disallow_kwargs_override": False, "_expand_input_attr": "expand_input", + "next_method": None, + "start_trigger": None, } op = BaseSerialization.deserialize(serialized) @@ -2312,6 +2381,8 @@ def test_operator_expand_kwargs_literal_serde(strict): "ui_fgcolor": "#000", "_disallow_kwargs_override": strict, "_expand_input_attr": "expand_input", + "next_method": None, + "start_trigger": None, } op = BaseSerialization.deserialize(serialized) @@ -2358,6 +2429,8 @@ def test_operator_expand_kwargs_xcomarg_serde(strict): "ui_fgcolor": "#000", "_disallow_kwargs_override": strict, "_expand_input_attr": "expand_input", + "next_method": None, + "start_trigger": None, } op = BaseSerialization.deserialize(serialized) @@ -2474,6 +2547,8 @@ def x(arg1, arg2, arg3): "template_fields_renderers": {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"}, "_disallow_kwargs_override": False, "_expand_input_attr": "op_kwargs_expand_input", + "next_method": None, + "start_trigger": None, } deserialized = BaseSerialization.deserialize(serialized) @@ -2538,6 +2613,8 @@ def x(arg1, arg2, arg3): "_task_module": "airflow.decorators.python", "_task_type": "_PythonDecoratedOperator", 
"_operator_name": "@task", + "next_method": None, + "start_trigger": None, "downstream_task_ids": [], "partial_kwargs": { "is_setup": False, @@ -2688,6 +2765,8 @@ def operator_extra_links(self): "_task_module": "tests.serialization.test_dag_serialization", "_is_empty": False, "_is_mapped": True, + "next_method": None, + "start_trigger": None, } deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag[Encoding.VAR]) assert deserialized_dag.task_dict["task"].operator_extra_links == [AirflowLink2()]