Allowing task_group Sensor in ExternalTaskSensor #24902

Merged
merged 14 commits into from
Aug 22, 2022
17 changes: 15 additions & 2 deletions airflow/example_dags/example_external_task_marker_dag.py
@@ -80,5 +80,18 @@
mode="reschedule",
)
# [END howto_operator_external_task_sensor]
child_task2 = EmptyOperator(task_id="child_task2")
child_task1 >> child_task2

# [START howto_operator_external_task_sensor_with_task_group]
child_task2 = ExternalTaskSensor(
task_id="child_task2",
external_dag_id=parent_dag.dag_id,
external_task_group_id='parent_dag_task_group_id',
timeout=600,
allowed_states=['success'],
failed_states=['failed', 'skipped'],
mode="reschedule",
)
# [END howto_operator_external_task_sensor_with_task_group]

child_task3 = EmptyOperator(task_id="child_task3")
child_task1 >> child_task2 >> child_task3
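
A minimal sketch of the parent-DAG side that the new sensor block above could wait on. The TaskGroup id ``parent_dag_task_group_id`` matches the sensor's ``external_task_group_id``; the DAG id and the task names inside the group are illustrative assumptions, not part of this PR.

```python
from datetime import datetime

from airflow import DAG
from airflow.operators.empty import EmptyOperator
from airflow.utils.task_group import TaskGroup

with DAG(dag_id="example_external_task_marker_parent", start_date=datetime(2022, 1, 1)) as parent_dag:
    # The sensor above polls the states of every task inside this group.
    with TaskGroup(group_id="parent_dag_task_group_id"):
        EmptyOperator(task_id="task_a") >> EmptyOperator(task_id="task_b")
```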
7 changes: 7 additions & 0 deletions airflow/models/dag.py
@@ -2221,6 +2221,13 @@ def filter_task_group(group, parent_group):
def has_task(self, task_id: str):
return task_id in self.task_dict

def has_task_group(self, task_group_id: str) -> bool:
return task_group_id in self.task_group_dict

@cached_property
def task_group_dict(self):
return {k: v for k, v in self._task_group.get_task_group_dict().items() if k is not None}

def get_task(self, task_id: str, include_subdags: bool = False) -> Operator:
if task_id in self.task_dict:
return self.task_dict[task_id]
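
A quick, hypothetical illustration of the two helpers added above. ``get_task_group_dict()`` flattens nested groups into a ``{group_id: TaskGroup}`` mapping whose root entry is keyed by ``None``, which ``task_group_dict`` filters out; the group and task names below are made up.

```python
from airflow.operators.empty import EmptyOperator
from airflow.utils.task_group import TaskGroup

with TaskGroup(group_id="outer"):  # assumes an enclosing DAG context, bound to `dag`
    with TaskGroup(group_id="inner"):
        EmptyOperator(task_id="leaf")

# Nested groups are keyed by their dot-qualified ids.
sorted(dag.task_group_dict)      # e.g. ['outer', 'outer.inner']
dag.has_task_group("outer")      # True
dag.has_task_group("missing")    # False
```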
125 changes: 92 additions & 33 deletions airflow/sensors/external_task.py
@@ -36,6 +36,9 @@
from airflow.utils.session import provide_session
from airflow.utils.state import State

if TYPE_CHECKING:
from sqlalchemy.orm import Query


class ExternalDagLink(BaseOperatorLink):
"""
@@ -54,9 +57,13 @@ def get_link(self, operator, dttm)

class ExternalTaskSensor(BaseSensorOperator):
"""
Waits for a different DAG or a task in a different DAG to complete for a
Waits for a different DAG, a task group, or a task in a different DAG to complete for a
specific logical date.

If both `external_task_group_id` and `external_task_id` are ``None`` (default), the sensor
waits for the DAG.
Values for `external_task_group_id` and `external_task_id` can't be set at the same time.

By default the ExternalTaskSensor will wait for the external task to
succeed, at which point it will also succeed. However, by default it will
*not* fail if the external task fails, but will continue to check the status
@@ -78,7 +85,7 @@ class ExternalTaskSensor(BaseSensorOperator):
:param external_dag_id: The dag_id that contains the task you want to
wait for
:param external_task_id: The task_id that contains the task you want to
wait for. If ``None`` (default value) the sensor waits for the DAG
wait for.
:param external_task_ids: The list of task_ids that you want to wait for.
If ``None`` (default value) the sensor waits for the DAG. Either
external_task_id or external_task_ids can be passed to
@@ -111,6 +118,7 @@ def __init__(
external_dag_id: str,
external_task_id: Optional[str] = None,
external_task_ids: Optional[Collection[str]] = None,
external_task_group_id: Optional[str] = None,
allowed_states: Optional[Iterable[str]] = None,
failed_states: Optional[Iterable[str]] = None,
execution_delta: Optional[datetime.timedelta] = None,
@@ -139,18 +147,25 @@ def __init__(
if external_task_id is not None:
external_task_ids = [external_task_id]

if external_task_ids:
if external_task_group_id and external_task_ids:
raise ValueError(
"Values for `external_task_group_id` and `external_task_id` or `external_task_ids` "
"can't be set at the same time"
)

if external_task_ids or external_task_group_id:
if not total_states <= set(State.task_states):
raise ValueError(
f'Valid values for `allowed_states` and `failed_states` '
f'when `external_task_id` or `external_task_ids` is not `None`: {State.task_states}'
f'when `external_task_id` or `external_task_ids` or `external_task_group_id` '
f'is not `None`: {State.task_states}'
)
if len(external_task_ids) > len(set(external_task_ids)):
if external_task_ids and len(external_task_ids) > len(set(external_task_ids)):
raise ValueError('Duplicate task_ids passed in external_task_ids parameter')
elif not total_states <= set(State.dag_states):
raise ValueError(
f'Valid values for `allowed_states` and `failed_states` '
f'when `external_task_id` is `None`: {State.dag_states}'
f'when `external_task_id` and `external_task_group_id` is `None`: {State.dag_states}'
)
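
The constructor now enforces that the three targeting modes are mutually exclusive and that the supplied states match the target. A sketch of the accepted combinations, assuming an external DAG id ``parent`` and illustrative task/group ids:

```python
from airflow.sensors.external_task import ExternalTaskSensor

# Each sensor targets exactly one of: the whole DAG run, task(s), or a task group.
ExternalTaskSensor(task_id="wait_dag", external_dag_id="parent")  # whole DAG run
ExternalTaskSensor(task_id="wait_task", external_dag_id="parent", external_task_id="t")
ExternalTaskSensor(
    task_id="wait_group",
    external_dag_id="parent",
    external_task_group_id="g",
    # Task-level states are valid here; 'skipped' is a task state but not a
    # DAG-run state, so it would raise ValueError on a whole-DAG sensor.
    failed_states=["failed", "skipped"],
)

# Raises ValueError: a task group and task id(s) can't be combined.
# ExternalTaskSensor(..., external_task_id="t", external_task_group_id="g")
```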

if execution_delta is not None and execution_date_fn is not None:
@@ -164,27 +179,39 @@ def __init__(
self.external_dag_id = external_dag_id
self.external_task_id = external_task_id
self.external_task_ids = external_task_ids
self.external_task_group_id = external_task_group_id
self.check_existence = check_existence
self._has_checked_existence = False

@provide_session
def poke(self, context, session=None):
def _get_dttm_filter(self, context):
if self.execution_delta:
dttm = context['logical_date'] - self.execution_delta
elif self.execution_date_fn:
dttm = self._handle_execution_date_fn(context=context)
else:
dttm = context['logical_date']
return dttm if isinstance(dttm, list) else [dttm]
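
``_get_dttm_filter`` factors the date-resolution logic out of ``poke``: it returns the list of logical dates to match in the external DAG, shifted by ``execution_delta`` or computed by ``execution_date_fn`` when either is set. A usage sketch, with an assumed external DAG id:

```python
import datetime

from airflow.sensors.external_task import ExternalTaskSensor

# If the external DAG runs one hour earlier on the same schedule, shift the
# lookup back so the sensor matches the corresponding parent run.
ExternalTaskSensor(
    task_id="wait_shifted",
    external_dag_id="hourly_parent",   # illustrative dag id
    external_task_group_id="g",
    execution_delta=datetime.timedelta(hours=1),
)
```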

dttm_filter = dttm if isinstance(dttm, list) else [dttm]
@provide_session
def poke(self, context, session=None):
dttm_filter = self._get_dttm_filter(context)
serialized_dttm_filter = ','.join(dt.isoformat() for dt in dttm_filter)

self.log.info(
'Poking for tasks %s in dag %s on %s ... ',
self.external_task_ids,
self.external_dag_id,
serialized_dttm_filter,
)
if self.external_task_ids:
self.log.info(
'Poking for tasks %s in dag %s on %s ... ',
self.external_task_ids,
self.external_dag_id,
serialized_dttm_filter,
)

if self.external_task_group_id:
self.log.info(
"Poking for task_group '%s' in dag '%s' on %s ... ",
self.external_task_group_id,
self.external_dag_id,
serialized_dttm_filter,
)

# In poke mode this will check dag existence only once
if self.check_existence and not self._has_checked_existence:
@@ -207,6 +234,17 @@ def poke(self, context, session=None):
f'Some of the external tasks {self.external_task_ids} '
f'in DAG {self.external_dag_id} failed.'
)
elif self.external_task_group_id:
if self.soft_fail:
raise AirflowSkipException(
f"The external task_group '{self.external_task_group_id}' "
f"in DAG '{self.external_dag_id}' failed. Skipping due to soft_fail."
)
raise AirflowException(
f"The external task_group '{self.external_task_group_id}' "
f"in DAG '{self.external_dag_id}' failed."
)

else:
if self.soft_fail:
raise AirflowSkipException(
@@ -217,7 +255,7 @@ def poke(self, context, session=None):
return count_allowed == len(dttm_filter)

def _check_for_existence(self, session) -> None:
dag_to_wait = session.query(DagModel).filter(DagModel.dag_id == self.external_dag_id).first()
dag_to_wait = DagModel.get_current(self.external_dag_id, session)

if not dag_to_wait:
raise AirflowException(f'The external DAG {self.external_dag_id} does not exist.')
@@ -233,6 +271,15 @@ def _check_for_existence(self, session) -> None:
f'The external task {external_task_id} in '
f'DAG {self.external_dag_id} does not exist.'
)

if self.external_task_group_id:
refreshed_dag_info = DagBag(dag_to_wait.fileloc).get_dag(self.external_dag_id)
if not refreshed_dag_info.has_task_group(self.external_task_group_id):
raise AirflowException(
f"The external task group '{self.external_task_group_id}' in "
f"DAG '{self.external_dag_id}' does not exist."
)

self._has_checked_existence = True

def get_count(self, dttm_filter, session, states) -> int:
@@ -251,28 +298,40 @@ def get_count(self, dttm_filter, session, states) -> int:

if self.external_task_ids:
count = (
session.query(func.count()) # .count() is inefficient
.filter(
TI.dag_id == self.external_dag_id,
TI.task_id.in_(self.external_task_ids),
TI.state.in_(states),
TI.execution_date.in_(dttm_filter),
)
self._count_query(TI, session, states, dttm_filter)
.filter(TI.task_id.in_(self.external_task_ids))
.scalar()
)
count = count / len(self.external_task_ids)
else:
) / len(self.external_task_ids)
elif self.external_task_group_id:
external_task_group_task_ids = self.get_external_task_group_task_ids(session)
count = (
session.query(func.count())
.filter(
DR.dag_id == self.external_dag_id,
DR.state.in_(states),
DR.execution_date.in_(dttm_filter),
)
self._count_query(TI, session, states, dttm_filter)
.filter(TI.task_id.in_(external_task_group_task_ids))
.scalar()
)
) / len(external_task_group_task_ids)
else:
count = self._count_query(DR, session, states, dttm_filter).scalar()
return count

def _count_query(self, model, session, states, dttm_filter) -> "Query":
query = session.query(func.count()).filter(
model.dag_id == self.external_dag_id,
model.state.in_(states), # pylint: disable=no-member
model.execution_date.in_(dttm_filter),
)
return query

def get_external_task_group_task_ids(self, session):
refreshed_dag_info = DagBag(read_dags_from_db=True).get_dag(self.external_dag_id, session)
task_group = refreshed_dag_info.task_group_dict.get(self.external_task_group_id)

if task_group:
return [task.task_id for task in task_group]

# returning default task_id as group_id itself, this will avoid any failure in case of
# 'check_existence=False' and will fail on timeout
return [self.external_task_group_id]
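
The counting logic is easiest to see in plain Python. For task or task-group targets, one ``COUNT(*)`` over matching ``TaskInstance`` rows divided by the number of task ids gives the number of logical dates on which *every* task reached one of the given states; ``poke`` returns ``True`` when that equals ``len(dttm_filter)``. A rough, illustrative equivalent (not the actual SQL path):

```python
def group_is_done(task_instances, task_ids, dttm_filter, allowed_states):
    matching = [
        ti for ti in task_instances
        if ti.task_id in task_ids
        and ti.state in allowed_states
        and ti.execution_date in dttm_filter
    ]
    # len(matching) mirrors the single COUNT(*) issued by _count_query.
    return len(matching) / len(task_ids) == len(dttm_filter)
```

Note the fallback in ``get_external_task_group_task_ids``: when the group is not found, the group id itself is returned as the only "task id", so the count stays at zero and the sensor times out rather than crashing when ``check_existence=False``.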

def _handle_execution_date_fn(self, context) -> Any:
"""
This function is to handle backwards compatibility with how this operator was
10 changes: 10 additions & 0 deletions docs/apache-airflow/howto/operator/external_task_sensor.rst
@@ -53,6 +53,16 @@ via ``allowed_states`` and ``failed_states`` parameters.
:start-after: [START howto_operator_external_task_sensor]
:end-before: [END howto_operator_external_task_sensor]

ExternalTaskSensor with task_group dependency
---------------------------------------------
In addition, we can use the :class:`~airflow.sensors.external_task.ExternalTaskSensor` to make tasks in a DAG
wait for a ``task_group`` in a different DAG for a specific ``execution_date``.

.. exampleinclude:: /../../airflow/example_dags/example_external_task_marker_dag.py
:language: python
:dedent: 4
:start-after: [START howto_operator_external_task_sensor_with_task_group]
:end-before: [END howto_operator_external_task_sensor_with_task_group]
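
For reference, a hand-written sketch (names are illustrative) combining the task-group target with ``check_existence``, which makes the sensor fail immediately if the external DAG or task group does not exist instead of waiting for the timeout:

.. code-block:: python

    from airflow.sensors.external_task import ExternalTaskSensor

    wait_for_group = ExternalTaskSensor(
        task_id="wait_for_group",
        external_dag_id="parent_dag",
        external_task_group_id="parent_dag_task_group_id",
        check_existence=True,  # raises if the DAG or group is missing
        mode="reschedule",
    )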


ExternalTaskMarker