Allowing tasks to start execution directly from triggerer without going to worker (#38674)
Lee-W authored Apr 29, 2024
1 parent cb73815 commit 6112745
Showing 9 changed files with 222 additions and 7 deletions.
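
At a glance, the change lets an operator declare a class-level ``start_trigger`` and ``next_method`` so the scheduler can hand the task straight to the triggerer, skipping the worker entirely. A minimal usage sketch (the sensor mirrors the example this commit adds to the docs below; the DAG id and dates are illustrative, not part of the commit):

from datetime import timedelta

from airflow.models.dag import DAG
from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.temporal import TimeDeltaTrigger
from airflow.utils import timezone


class WaitOneHourSensor(BaseSensorOperator):
    # Both attributes must be set together; the scheduler then defers the
    # task immediately instead of queuing it on a worker first.
    start_trigger = TimeDeltaTrigger(timedelta(hours=1))
    next_method = "execute_complete"

    def execute_complete(self, context, event=None):
        # The trigger has fired; there is nothing left to do.
        return


with DAG(dag_id="example_start_trigger", schedule=None, start_date=timezone.datetime(2024, 1, 1)):
    WaitOneHourSensor(task_id="wait")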
2 changes: 2 additions & 0 deletions airflow/decorators/base.py
@@ -509,6 +509,8 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
# task's expand() contribute to the op_kwargs operator argument, not
# the operator arguments themselves, and should expand against it.
expand_input_attr="op_kwargs_expand_input",
start_trigger=self.operator_class.start_trigger,
next_method=self.operator_class.next_method,
)
return XComArg(operator=operator)

2 changes: 2 additions & 0 deletions airflow/models/abstractoperator.py
@@ -122,6 +122,8 @@ class AbstractOperator(Templater, DAGNode):
"node_id", # Duplicates task_id
"task_group", # Doesn't have a useful repr, no point showing in UI
"inherits_from_empty_operator", # impl detail
"start_trigger",
"next_method",
# For compatibility with TG, for operators these are just the current task, no point showing
"roots",
"leaves",
5 changes: 5 additions & 0 deletions airflow/models/baseoperator.py
@@ -818,6 +818,9 @@ def say_hello_world(**context):
# Set to True for an operator instantiated by a mapped operator.
__from_mapped = False

start_trigger: BaseTrigger | None = None
next_method: str | None = None

def __init__(
self,
task_id: str,
@@ -1675,6 +1678,8 @@ def get_serialized_fields(cls):
"is_teardown",
"on_failure_fail_dagrun",
"map_index_template",
"start_trigger",
"next_method",
}
)
DagContext.pop_context_managed_dag()
22 changes: 18 additions & 4 deletions airflow/models/dagrun.py
@@ -51,7 +51,7 @@
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.callbacks.callback_requests import DagCallbackRequest
from airflow.configuration import conf as airflow_conf
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskNotFound
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskDeferred, TaskNotFound
from airflow.listeners.listener import get_listener_manager
from airflow.models import Log
from airflow.models.abstractoperator import NotMapped
@@ -905,9 +905,11 @@ def recalculate(self) -> _UnfinishedStates:
self.run_id,
self.start_date,
self.end_date,
(self.end_date - self.start_date).total_seconds()
if self.start_date and self.end_date
else None,
(
(self.end_date - self.start_date).total_seconds()
if self.start_date and self.end_date
else None
),
self._state,
self.external_trigger,
self.run_type,
@@ -1537,6 +1539,18 @@ def schedule_tis(
and not ti.task.outlets
):
dummy_ti_ids.append((ti.task_id, ti.map_index))
elif (
ti.task.start_trigger is not None
and ti.task.next_method is not None
and not ti.task.on_execute_callback
and not ti.task.on_success_callback
and not ti.task.outlets
):
ti._try_number += 1
ti.defer_task(
defer=TaskDeferred(trigger=ti.task.start_trigger, method_name=ti.task.next_method),
session=session,
)
else:
schedulable_ti_ids.append((ti.task_id, ti.map_index))

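The new branch in ``schedule_tis`` is a fast path: when a task declares both ``start_trigger`` and ``next_method`` and carries no side effects that need a worker, the scheduler defers it straight to the triggerer instead of queuing it. A condensed sketch of just that eligibility test, written as a hypothetical helper (the real check is inline in the method above):

def _eligible_for_direct_deferral(task) -> bool:
    # Mirrors the condition added to DagRun.schedule_tis: both attributes must
    # be set, and callbacks/outlets that would require a worker must be absent.
    return (
        task.start_trigger is not None
        and task.next_method is not None
        and not task.on_execute_callback
        and not task.on_success_callback
        and not task.outlets
    )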
7 changes: 7 additions & 0 deletions airflow/models/mappedoperator.py
@@ -81,6 +81,7 @@
from airflow.models.param import ParamsDict
from airflow.models.xcom_arg import XComArg
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.triggers.base import BaseTrigger
from airflow.utils.context import Context
from airflow.utils.operator_resources import Resources
from airflow.utils.task_group import TaskGroup
@@ -236,6 +237,8 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
# For classic operators, this points to expand_input because kwargs
# to BaseOperator.expand() contribute to operator arguments.
expand_input_attr="expand_input",
start_trigger=self.operator_class.start_trigger,
next_method=self.operator_class.next_method,
)
return op

@@ -278,6 +281,8 @@ class MappedOperator(AbstractOperator):
_task_module: str
_task_type: str
_operator_name: str
start_trigger: BaseTrigger | None
next_method: str | None

dag: DAG | None
task_group: TaskGroup | None
@@ -306,6 +311,8 @@ class MappedOperator(AbstractOperator):
(
"parse_time_mapped_ti_count",
"operator_class",
"start_trigger",
"next_method",
)
)

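Because ``_expand`` copies ``start_trigger`` and ``next_method`` from the operator class and ``MappedOperator`` declares matching fields, class-level values survive dynamic task mapping; instance-level assignment does not, as the docs below warn. An illustrative check, assuming the ``WaitOneHourSensor`` class from the earlier sketch, a surrounding ``with DAG(...)`` block, and arbitrary expand values:

# The mapped operator reads the attributes from operator_class at expand time.
mapped = WaitOneHourSensor.partial(task_id="wait").expand(poke_interval=[10.0, 30.0])
assert mapped.start_trigger is WaitOneHourSensor.start_trigger
assert mapped.next_method == "execute_complete"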
30 changes: 27 additions & 3 deletions airflow/serialization/serialized_objects.py
@@ -70,6 +70,7 @@
from airflow.utils.code_utils import get_python_source
from airflow.utils.context import Context, DatasetEventAccessor, DatasetEventAccessors
from airflow.utils.docs import get_docs_url
from airflow.utils.helpers import exactly_one
from airflow.utils.module_loading import import_string, qualname
from airflow.utils.operator_resources import Resources
from airflow.utils.task_group import MappedTaskGroup, TaskGroup
@@ -1016,6 +1017,7 @@ def serialize_operator(cls, op: BaseOperator | MappedOperator) -> dict[str, Any]
def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool) -> dict[str, Any]:
"""Serialize operator into a JSON object."""
serialize_op = cls.serialize_to_json(op, cls._decorated_fields)

serialize_op["_task_type"] = getattr(op, "_task_type", type(op).__name__)
serialize_op["_task_module"] = getattr(op, "_task_module", type(op).__module__)
if op.operator_name != serialize_op["_task_type"]:
@@ -1024,6 +1026,12 @@ def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool)
# Used to determine if an Operator is inherited from EmptyOperator
serialize_op["_is_empty"] = op.inherits_from_empty_operator

if exactly_one(op.start_trigger is not None, op.next_method is not None):
raise AirflowException("start_trigger and next_method should both be set.")

serialize_op["start_trigger"] = op.start_trigger.serialize() if op.start_trigger else None
serialize_op["next_method"] = op.next_method

if op.operator_extra_links:
serialize_op["_operator_extra_links"] = cls._serialize_operator_extra_links(
op.operator_extra_links.__get__(op)
@@ -1204,6 +1212,17 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None:
# Used to determine if an Operator is inherited from EmptyOperator
setattr(op, "_is_empty", bool(encoded_op.get("_is_empty", False)))

# Deserialize start_trigger
serialized_start_trigger = encoded_op.get("start_trigger")
if serialized_start_trigger:
trigger_cls_name, trigger_kwargs = serialized_start_trigger
trigger_cls = import_string(trigger_cls_name)
start_trigger = trigger_cls(**trigger_kwargs)
setattr(op, "start_trigger", start_trigger)
else:
setattr(op, "start_trigger", None)
setattr(op, "next_method", encoded_op.get("next_method", None))

@staticmethod
def set_task_dag_references(task: Operator, dag: DAG) -> None:
"""Handle DAG references on an operator.
@@ -1241,6 +1260,7 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator:
operator_name = encoded_op["_operator_name"]
except KeyError:
operator_name = encoded_op["_task_type"]

op = MappedOperator(
operator_class=op_data,
expand_input=EXPAND_INPUT_EMPTY,
@@ -1264,6 +1284,8 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator:
end_date=None,
disallow_kwargs_override=encoded_op["_disallow_kwargs_override"],
expand_input_attr=encoded_op["_expand_input_attr"],
start_trigger=None,
next_method=None,
)
else:
op = SerializedBaseOperator(task_id=encoded_op["task_id"])
@@ -1686,9 +1708,11 @@ def set_ref(task: Operator) -> Operator:
return task

group.children = {
label: set_ref(task_dict[val])
if _type == DAT.OP
else cls.deserialize_task_group(val, group, task_dict, dag=dag)
label: (
set_ref(task_dict[val])
if _type == DAT.OP
else cls.deserialize_task_group(val, group, task_dict, dag=dag)
)
for label, (_type, val) in encoded_group["children"].items()
}
group.upstream_group_ids.update(cls.deserialize(encoded_group["upstream_group_ids"]))
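A trigger's serialized form is the ``(classpath, kwargs)`` pair returned by ``BaseTrigger.serialize()``; deserialization, as the ``populate_operator`` hunk above shows, simply re-imports the class and re-applies the kwargs. A standalone sketch of that round trip:

from datetime import timedelta

from airflow.triggers.temporal import TimeDeltaTrigger
from airflow.utils.module_loading import import_string

trigger = TimeDeltaTrigger(timedelta(hours=1))

# serialize() returns a (classpath, kwargs) pair suitable for JSON storage.
classpath, kwargs = trigger.serialize()

# Rebuild an equivalent trigger the same way populate_operator does.
rebuilt = import_string(classpath)(**kwargs)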
54 changes: 54 additions & 0 deletions docs/apache-airflow/authoring-and-scheduling/deferring.rst
@@ -141,6 +141,60 @@ The ``self.defer`` call raises the ``TaskDeferred`` exception, so it can work an

``execution_timeout`` on operators is determined from the *total runtime*, not individual executions between deferrals. This means that if ``execution_timeout`` is set, an operator can fail while it's deferred or while it's running after a deferral, even if it's only been resumed for a few seconds.

Triggering Deferral from Start
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If you want to defer your task directly to the triggerer without it ever reaching a worker, you can add the class-level attributes ``start_trigger`` and ``next_method`` to your deferrable operator.

* ``start_trigger``: An instance of a trigger you want to defer to. It will be serialized into the database.
* ``next_method``: The method name on your operator that you want Airflow to call when it resumes.


This is particularly useful when deferring is the only thing the ``execute`` method does. Here's a basic refinement of the previous example.

.. code-block:: python

    from datetime import timedelta
    from typing import Any

    from airflow.sensors.base import BaseSensorOperator
    from airflow.triggers.temporal import TimeDeltaTrigger
    from airflow.utils.context import Context


    class WaitOneHourSensor(BaseSensorOperator):
        start_trigger = TimeDeltaTrigger(timedelta(hours=1))
        next_method = "execute_complete"

        def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
            # We have no more work to do here. Mark as complete.
            return

``start_trigger`` and ``next_method`` can also be set at the instance level for more flexible configuration.

.. warning::
    Dynamic task mapping is not supported when ``start_trigger`` and ``next_method`` are assigned at the instance level.

.. code-block:: python

    from datetime import timedelta
    from typing import Any

    from airflow.sensors.base import BaseSensorOperator
    from airflow.triggers.temporal import TimeDeltaTrigger
    from airflow.utils.context import Context


    class WaitOneHourSensor(BaseSensorOperator):
        def __init__(self, *args: list[Any], **kwargs: dict[str, Any]) -> None:
            super().__init__(*args, **kwargs)
            self.start_trigger = TimeDeltaTrigger(timedelta(hours=1))
            self.next_method = "execute_complete"

        def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
            # We have no more work to do here. Mark as complete.
            return

Writing Triggers
~~~~~~~~~~~~~~~~

28 changes: 28 additions & 0 deletions tests/models/test_dagrun.py
@@ -40,6 +40,7 @@
from airflow.operators.python import ShortCircuitOperator
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.stats import Stats
from airflow.triggers.testing import SuccessTrigger
from airflow.utils import timezone
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.trigger_rule import TriggerRule
@@ -1986,6 +1987,33 @@ def test_schedule_tis_map_index(dag_maker, session):
assert ti2.state == TaskInstanceState.SUCCESS


def test_schedule_tis_start_trigger(dag_maker, session):
"""
Test that an operator with start_trigger and next_method set can be deferred
directly during scheduling, without going to a worker.
"""
trigger = SuccessTrigger()

class TestOperator(BaseOperator):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.start_trigger = trigger
self.next_method = "execute_complete"

def execute_complete(self):
pass

with dag_maker(session=session):
task = TestOperator(task_id="test_task")

dr: DagRun = dag_maker.create_dagrun()

ti = TI(task=task, run_id=dr.run_id, state=None)
assert ti.state is None
dr.schedule_tis((ti,), session=session)
assert ti.state == TaskInstanceState.DEFERRED


def test_mapped_expand_kwargs(dag_maker):
with dag_maker():
