Skip to content

Commit

Permalink
feat: soft_fail TriggerDagRunOperator (#39173)
Browse files Browse the repository at this point in the history
* feat: soft_fail TriggerDagRunOperator

* review_1

---------

Co-authored-by: raphaelauv <raphaelauv@users.noreply.github.com>
  • Loading branch information
raphaelauv and raphaelauv authored May 6, 2024
1 parent b594a8d commit eb48911
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
16 changes: 15 additions & 1 deletion airflow/operators/trigger_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@

from airflow.api.common.trigger_dag import trigger_dag
from airflow.configuration import conf
from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists, RemovedInAirflow3Warning
from airflow.exceptions import (
AirflowException,
AirflowSkipException,
DagNotFound,
DagRunAlreadyExists,
RemovedInAirflow3Warning,
)
from airflow.models.baseoperator import BaseOperator
from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.dag import DagModel
Expand Down Expand Up @@ -90,6 +96,7 @@ class TriggerDagRunOperator(BaseOperator):
(default: 60)
:param allowed_states: List of allowed states, default is ``['success']``.
:param failed_states: List of failed or dis-allowed states, default is ``None``.
:param skip_when_already_exists: Set to true to mark the task as SKIPPED if a dag_run already exists
:param deferrable: If waiting for completion, whether or not to defer the task until done,
default is ``False``.
:param execution_date: Deprecated parameter; same as ``logical_date``.
Expand All @@ -101,6 +108,7 @@ class TriggerDagRunOperator(BaseOperator):
"logical_date",
"conf",
"wait_for_completion",
"skip_when_already_exists",
)
template_fields_renderers = {"conf": "py"}
ui_color = "#ffefeb"
Expand All @@ -118,6 +126,7 @@ def __init__(
poke_interval: int = 60,
allowed_states: list[str] | None = None,
failed_states: list[str] | None = None,
skip_when_already_exists: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
execution_date: str | datetime.datetime | None = None,
**kwargs,
Expand All @@ -137,6 +146,7 @@ def __init__(
self.failed_states = [DagRunState(s) for s in failed_states]
else:
self.failed_states = [DagRunState.FAILED]
self.skip_when_already_exists = skip_when_already_exists
self._defer = deferrable

if execution_date is not None:
Expand Down Expand Up @@ -196,6 +206,10 @@ def execute(self, context: Context):
dag_run = e.dag_run
dag.clear(start_date=dag_run.logical_date, end_date=dag_run.logical_date)
else:
if self.skip_when_already_exists:
raise AirflowSkipException(
"Skipping due to skip_when_already_exists is set to True and DagRunAlreadyExists"
)
raise e
if dag_run is None:
raise RuntimeError("The dag_run should be set here!")
Expand Down
19 changes: 18 additions & 1 deletion tests/operators/test_trigger_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from airflow.triggers.external_task import DagStateTrigger
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.types import DagRunType

pytestmark = pytest.mark.db_test
Expand Down Expand Up @@ -322,6 +322,23 @@ def test_trigger_dagrun_with_reset_dag_run_false_fail(self, trigger_run_id, trig
with pytest.raises(DagRunAlreadyExists):
task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True)

def test_trigger_dagrun_with_skip_when_already_exists(self):
"""Test TriggerDagRunOperator with skip_when_already_exists."""
execution_date = DEFAULT_DATE
task = TriggerDagRunOperator(
task_id="test_task",
trigger_dag_id=TRIGGERED_DAG_ID,
trigger_run_id="dummy_run_id",
execution_date=None,
reset_dag_run=False,
skip_when_already_exists=True,
dag=self.dag,
)
task.run(start_date=execution_date, end_date=execution_date, ignore_ti_state=True)
assert task.get_task_instances()[0].state == TaskInstanceState.SUCCESS
task.run(start_date=execution_date, end_date=execution_date, ignore_ti_state=True)
assert task.get_task_instances()[0].state == TaskInstanceState.SKIPPED

@pytest.mark.parametrize(
"trigger_run_id, trigger_logical_date, expected_dagruns_count",
[
Expand Down

0 comments on commit eb48911

Please sign in to comment.