Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ExternalTaskSensor respects soft_fail if the external task enters a failed_state #23647

Merged
merged 4 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion airflow/sensors/external_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import attr
from sqlalchemy import func

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.models.baseoperator import BaseOperatorLink
from airflow.models.dag import DagModel
from airflow.models.dagbag import DagBag
Expand Down Expand Up @@ -57,6 +57,24 @@ class ExternalTaskSensor(BaseSensorOperator):
Waits for a different DAG or a task in a different DAG to complete for a
specific logical date.

By default the ExternalTaskSensor will wait for the external task to
succeed, at which point it will also succeed. However, by default it will
*not* fail if the external task fails, but will continue to check the status
until the sensor times out (thus giving you time to retry the external task
without also having to clear the sensor).

It is possible to alter the default behavior by setting states which
cause the sensor to fail, e.g. by setting ``allowed_states=[State.FAILED]``
and ``failed_states=[State.SUCCESS]`` you will flip the behaviour to get a
sensor which goes green when the external task *fails* and immediately goes
red if the external task *succeeds*!

Note that ``soft_fail`` is respected when examining the failed_states. Thus
if the external task enters a failed state and ``soft_fail == True`` the
sensor will _skip_ rather than fail. As a result, setting ``soft_fail=True``
and ``failed_states=[State.SKIPPED]`` will result in the sensor skipping if
the external task skips.

:param external_dag_id: The dag_id that contains the task you want to
wait for
:param external_task_id: The task_id that contains the task you want to
Expand Down Expand Up @@ -184,11 +202,20 @@ def poke(self, context, session=None):

if count_failed == len(dttm_filter):
if self.external_task_ids:
if self.soft_fail:
raise AirflowSkipException(
f'Some of the external tasks {self.external_task_ids} '
f'in DAG {self.external_dag_id} failed. Skipping due to soft_fail.'
)
raise AirflowException(
f'Some of the external tasks {self.external_task_ids} '
f'in DAG {self.external_dag_id} failed.'
)
else:
if self.soft_fail:
raise AirflowSkipException(
f'The external DAG {self.external_dag_id} failed. Skipping due to soft_fail.'
)
raise AirflowException(f'The external DAG {self.external_dag_id} failed.')

return count_allowed == len(dttm_filter)
Expand Down
1 change: 1 addition & 0 deletions newsfragments/23647.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
``ExternalTaskSensor`` now supports the ``soft_fail`` flag to skip if external task or DAG enters a failed state.
57 changes: 49 additions & 8 deletions tests/sensors/test_external_task_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,36 @@ def test_external_task_sensor_failed_states_as_success(self):
"unit_test_dag failed."
)

def test_external_task_sensor_soft_fail_failed_states_as_skipped(self, session=None):
self.test_time_sensor()
op = ExternalTaskSensor(
task_id='test_external_task_sensor_check',
external_dag_id=TEST_DAG_ID,
external_task_id=TEST_TASK_ID,
allowed_states=[State.FAILED],
failed_states=[State.SUCCESS],
soft_fail=True,
dag=self.dag,
)

# when
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

# then
session = settings.Session()
TI = TaskInstance
task_instances: list[TI] = session.query(TI).filter(TI.task_id == op.task_id).all()
assert len(task_instances) == 1, "Unexpected number of task instances"
assert task_instances[0].state == State.SKIPPED, "Unexpected external task state"

def test_external_task_sensor_external_task_id_param(self):
"""Test external_task_ids is set properly when external_task_id is passed as a template"""
self.test_time_sensor()
op = ExternalTaskSensor(
task_id='test_external_task_sensor_check',
external_dag_id='{{ params.dag_id }}',
external_task_id='{{ params.task_id }}',
params={
'dag_id': TEST_DAG_ID,
'task_id': TEST_TASK_ID,
},
params={'dag_id': TEST_DAG_ID, 'task_id': TEST_TASK_ID},
dag=self.dag,
)

Expand All @@ -162,10 +181,7 @@ def test_external_task_sensor_external_task_ids_param(self):
task_id='test_external_task_sensor_check',
external_dag_id='{{ params.dag_id }}',
external_task_ids=['{{ params.task_id }}'],
params={
'dag_id': TEST_DAG_ID,
'task_id': TEST_TASK_ID,
},
params={'dag_id': TEST_DAG_ID, 'task_id': TEST_TASK_ID},
dag=self.dag,
)

Expand Down Expand Up @@ -214,6 +230,31 @@ def test_external_dag_sensor(self):
)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

def test_external_dag_sensor_soft_fail_as_skipped(self):
other_dag = DAG('other_dag', default_args=self.args, end_date=DEFAULT_DATE, schedule_interval='@once')
other_dag.create_dagrun(
run_id='test', start_date=DEFAULT_DATE, execution_date=DEFAULT_DATE, state=State.SUCCESS
)
op = ExternalTaskSensor(
task_id='test_external_dag_sensor_check',
external_dag_id='other_dag',
external_task_id=None,
allowed_states=[State.FAILED],
failed_states=[State.SUCCESS],
soft_fail=True,
dag=self.dag,
)

# when
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

# then
session = settings.Session()
TI = TaskInstance
task_instances: list[TI] = session.query(TI).filter(TI.task_id == op.task_id).all()
assert len(task_instances) == 1, "Unexpected number of task instances"
assert task_instances[0].state == State.SKIPPED, "Unexpected external task state"

def test_external_task_sensor_fn_multiple_execution_dates(self):
bash_command_code = """
{% set s=logical_date.time().second %}
Expand Down