Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deferrable TriggerDagRunOperator #30292

Merged
merged 18 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions airflow/operators/trigger_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,20 @@
import time
from typing import TYPE_CHECKING, Sequence, cast

from sqlalchemy.orm.exc import NoResultFound

from airflow.api.common.trigger_dag import trigger_dag
from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
from airflow.models.dag import DagModel
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun
from airflow.models.xcom import XCom
from airflow.triggers.external_task import DagStateTrigger
from airflow.utils import timezone
from airflow.utils.context import Context
from airflow.utils.helpers import build_airflow_url_with_query
from airflow.utils.session import provide_session
from airflow.utils.state import State
from airflow.utils.types import DagRunType

Expand All @@ -40,6 +44,8 @@


if TYPE_CHECKING:
from sqlalchemy.orm.session import Session

from airflow.models.taskinstance import TaskInstanceKey


Expand Down Expand Up @@ -79,6 +85,8 @@ class TriggerDagRunOperator(BaseOperator):
(default: 60)
:param allowed_states: List of allowed states, default is ``['success']``.
:param failed_states: List of failed or dis-allowed states, default is ``None``.
:param deferrable: If waiting for completion, whether or not to defer the task until done,
default is ``False``.
"""

template_fields: Sequence[str] = ("trigger_dag_id", "trigger_run_id", "execution_date", "conf")
Expand All @@ -98,6 +106,7 @@ def __init__(
poke_interval: int = 60,
allowed_states: list | None = None,
failed_states: list | None = None,
deferrable: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -109,6 +118,7 @@ def __init__(
self.poke_interval = poke_interval
self.allowed_states = allowed_states or [State.SUCCESS]
self.failed_states = failed_states or [State.FAILED]
self._defer = deferrable

if execution_date is not None and not isinstance(execution_date, (str, datetime.datetime)):
raise TypeError(
Expand All @@ -118,6 +128,7 @@ def __init__(
self.execution_date = execution_date

def execute(self, context: Context):

if isinstance(self.execution_date, datetime.datetime):
parsed_execution_date = self.execution_date
elif isinstance(self.execution_date, str):
Expand All @@ -134,6 +145,7 @@ def execute(self, context: Context):
run_id = self.trigger_run_id
else:
run_id = DagRun.generate_run_id(DagRunType.MANUAL, parsed_execution_date)

try:
dag_run = trigger_dag(
dag_id=self.trigger_dag_id,
Expand Down Expand Up @@ -168,6 +180,18 @@ def execute(self, context: Context):
ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id)

if self.wait_for_completion:

# Kick off the deferral process
if self._defer:
self.defer(
trigger=DagStateTrigger(
dag_id=self.trigger_dag_id,
states=self.allowed_states + self.failed_states,
execution_dates=[parsed_execution_date],
poll_interval=self.poke_interval,
),
method_name="execute_complete",
)
# wait for dag to complete
while True:
self.log.info(
Expand All @@ -185,3 +209,32 @@ def execute(self, context: Context):
if state in self.allowed_states:
self.log.info("%s finished with allowed state %s", self.trigger_dag_id, state)
return

@provide_session
def execute_complete(self, context: Context, session: Session, **kwargs):
    """
    Resume after deferral and validate the final state of the triggered DAG run.

    :param context: Task context; its ``execution_date`` is used as the lookup
        date only when the resume event does not carry one.
    :param session: SQLAlchemy session; supplied by ``provide_session`` when omitted.
    :param kwargs: Extra resume arguments. An ``event`` entry, when present, is
        inspected for the execution date the trigger was created with.
    :raises AirflowException: if no matching DAG run exists, if the run ended in one
        of ``failed_states``, or if its state is in neither ``failed_states`` nor
        ``allowed_states``.
    """
    # The run to check is the one ``execute`` deferred on, which is not necessarily
    # the run of the task being resumed. Prefer the date carried by the trigger event.
    # NOTE(review): assumes the event is the serialized DagStateTrigger, i.e.
    # ``(classpath, {"execution_dates": [...], ...})`` -- confirm against the trigger.
    lookup_date = None
    event = kwargs.get("event")
    if isinstance(event, (tuple, list)) and len(event) > 1 and isinstance(event[1], dict):
        event_dates = event[1].get("execution_dates") or []
        if event_dates:
            lookup_date = event_dates[0]
    if lookup_date is None:
        # No usable event (e.g. a direct call): keep the context-based lookup.
        lookup_date = context["execution_date"]

    try:
        dag_run = (
            session.query(DagRun)
            .filter(DagRun.dag_id == self.trigger_dag_id, DagRun.execution_date == lookup_date)
            .one()
        )
    except NoResultFound as err:
        # Report the date that was actually queried, and keep the original cause.
        raise AirflowException(
            f"No DAG run found for DAG {self.trigger_dag_id} and execution date {lookup_date}"
        ) from err

    state = dag_run.state

    if state in self.failed_states:
        raise AirflowException(f"{self.trigger_dag_id} failed with failed state {state}")
    if state in self.allowed_states:
        self.log.info("%s finished with allowed state %s", self.trigger_dag_id, state)
        return

    # The trigger fired on a state this operator does not recognise.
    raise AirflowException(
        f"{self.trigger_dag_id} return {state} which is not in {self.failed_states}"
        f" or {self.allowed_states}"
    )
88 changes: 88 additions & 0 deletions tests/operators/test_trigger_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,91 @@ def test_trigger_dagrun_triggering_itself_with_execution_date(self):
)
with pytest.raises(DagRunAlreadyExists):
task.run(start_date=execution_date, end_date=execution_date)

def test_trigger_dagrun_with_wait_for_completion_true_defer_false(self):
    """With ``deferrable=False`` and ``wait_for_completion=True``, one run is triggered."""
    run_date = DEFAULT_DATE
    operator = TriggerDagRunOperator(
        dag=self.dag,
        task_id="test_task",
        trigger_dag_id=TRIGGERED_DAG_ID,
        execution_date=run_date,
        wait_for_completion=True,
        deferrable=False,
        poke_interval=10,
        allowed_states=[State.QUEUED],
    )
    operator.run(start_date=run_date, end_date=run_date)

    # Exactly one run of the target DAG must exist afterwards.
    with create_session() as session:
        triggered_query = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID)
        triggered_runs = triggered_query.all()
        assert len(triggered_runs) == 1

def test_trigger_dagrun_with_wait_for_completion_true_defer_true(self):
    """With ``deferrable=True``, one run is triggered and ``execute_complete`` accepts QUEUED."""
    run_date = DEFAULT_DATE
    operator = TriggerDagRunOperator(
        dag=self.dag,
        task_id="test_task",
        trigger_dag_id=TRIGGERED_DAG_ID,
        execution_date=run_date,
        wait_for_completion=True,
        deferrable=True,
        poke_interval=10,
        allowed_states=[State.QUEUED],
    )

    operator.run(start_date=run_date, end_date=run_date)

    # Exactly one run of the target DAG must exist afterwards.
    with create_session() as session:
        triggered_query = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID)
        triggered_runs = triggered_query.all()
        assert len(triggered_runs) == 1

    # Resuming must not raise, since QUEUED is listed as an allowed state.
    resume_context = {"execution_date": run_date, "logical_date": run_date}
    operator.execute_complete(context=resume_context)

def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure(self):
    """With only SUCCESS allowed, ``execute_complete`` raises ``AirflowException``."""
    run_date = DEFAULT_DATE
    operator = TriggerDagRunOperator(
        dag=self.dag,
        task_id="test_task",
        trigger_dag_id=TRIGGERED_DAG_ID,
        execution_date=run_date,
        wait_for_completion=True,
        deferrable=True,
        poke_interval=10,
        allowed_states=[State.SUCCESS],
    )

    operator.run(start_date=run_date, end_date=run_date)

    # Exactly one run of the target DAG must exist afterwards.
    with create_session() as session:
        triggered_query = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID)
        triggered_runs = triggered_query.all()
        assert len(triggered_runs) == 1

    resume_context = {"execution_date": run_date, "logical_date": run_date}
    with pytest.raises(AirflowException):
        operator.execute_complete(context=resume_context)

def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure_2(self):
    """With QUEUED declared a failed state, ``execute_complete`` raises ``AirflowException``."""
    run_date = DEFAULT_DATE
    operator = TriggerDagRunOperator(
        dag=self.dag,
        task_id="test_task",
        trigger_dag_id=TRIGGERED_DAG_ID,
        execution_date=run_date,
        wait_for_completion=True,
        deferrable=True,
        poke_interval=10,
        allowed_states=[State.SUCCESS],
        failed_states=[State.QUEUED],
    )

    operator.run(start_date=run_date, end_date=run_date)

    # Exactly one run of the target DAG must exist afterwards.
    with create_session() as session:
        triggered_query = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID)
        triggered_runs = triggered_query.all()
        assert len(triggered_runs) == 1

    resume_context = {"execution_date": run_date, "logical_date": run_date}
    with pytest.raises(AirflowException):
        operator.execute_complete(context=resume_context)