From 619112ad5a2f3cb50242d8fcffd6daf51c02ec6f Mon Sep 17 00:00:00 2001 From: Ashish Patel Date: Tue, 28 Jun 2022 15:48:25 +0100 Subject: [PATCH 01/13] airflow-14563- task_group checking added --- airflow/models/dag.py | 6 ++++ airflow/sensors/external_task.py | 55 ++++++++++++++++++++++++++------ 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 79d5e0e780fd5..981cff265962e 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2220,6 +2220,12 @@ def filter_task_group(group, parent_group): def has_task(self, task_id: str): return task_id in self.task_dict + def has_task_group(self, task_group_id: str): + task_groups_ids = set( + task_id.split(".")[0] for task_id in self.task_ids if + "." in task_id) + return task_group_id in task_groups_ids + def get_task(self, task_id: str, include_subdags: bool = False) -> Operator: if task_id in self.task_dict: return self.task_dict[task_id] diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py index 7081651a6846d..12eea024a4ef9 100644 --- a/airflow/sensors/external_task.py +++ b/airflow/sensors/external_task.py @@ -54,9 +54,13 @@ def get_link(self, operator, dttm): class ExternalTaskSensor(BaseSensorOperator): """ - Waits for a different DAG or a task in a different DAG to complete for a + Waits for a different DAG, a task group, or a task in a different DAG to complete for a specific logical date. + If both `external_task_group_id` and `external_task_id` are ``None`` (default), the sensor + waits for the DAG. + Values for `external_task_group_id` and `external_task_id` can't be set at the same time. + By default the ExternalTaskSensor will wait for the external task to succeed, at which point it will also succeed. However, by default it will *not* fail if the external task fails, but will continue to check the status @@ -78,7 +82,7 @@ class ExternalTaskSensor(BaseSensorOperator): :param external_dag_id: The dag_id that contains the task you want to wait for :param external_task_id: The task_id that contains the task you want to - wait for. If ``None`` (default value) the sensor waits for the DAG + wait for. :param external_task_ids: The list of task_ids that you want to wait for. If ``None`` (default value) the sensor waits for the DAG. Either external_task_id or external_task_ids can be passed to @@ -111,6 +115,7 @@ def __init__( external_dag_id: str, external_task_id: Optional[str] = None, external_task_ids: Optional[Collection[str]] = None, + external_task_group_id: Optional[str] = None, allowed_states: Optional[Iterable[str]] = None, failed_states: Optional[Iterable[str]] = None, execution_delta: Optional[datetime.timedelta] = None, @@ -139,6 +144,12 @@ def __init__( if external_task_id is not None: external_task_ids = [external_task_id] + if external_task_group_id and external_task_ids: + raise ValueError( + "Values for `external_task_group_id` and `external_task_id` or `external_task_ids` " + "can't be set at the same time" + ) + if external_task_ids: if not total_states <= set(State.task_states): raise ValueError( @@ -164,19 +175,22 @@ def __init__( self.external_dag_id = external_dag_id self.external_task_id = external_task_id self.external_task_ids = external_task_ids + self.external_task_group_id = external_task_group_id self.check_existence = check_existence self._has_checked_existence = False - @provide_session - def poke(self, context, session=None): + def _get_dttm_filter(self, context): if self.execution_delta: dttm = context['logical_date'] - self.execution_delta elif self.execution_date_fn: dttm = self._handle_execution_date_fn(context=context) else: dttm = context['logical_date'] + return dttm if isinstance(dttm, list) else [dttm] - dttm_filter = dttm if isinstance(dttm, list) else [dttm] + @provide_session + def poke(self, context, session=None): + dttm_filter = self._get_dttm_filter(context) serialized_dttm_filter = ','.join(dt.isoformat() for dt in dttm_filter) self.log.info( @@ -207,6 +221,17 @@ def poke(self, context, session=None): f'Some of the external tasks {self.external_task_ids} ' f'in DAG {self.external_dag_id} failed.' ) + elif self.external_task_group_id: + if self.soft_fail: + raise AirflowSkipException( + f"The external task group {self.external_task_group_id}" + f"in DAG {self.external_dag_id} failed. Skipping due to soft_fail." + ) + raise AirflowException( + f"The external task group {self.external_task_group_id}" + f"in DAG {self.external_dag_id} failed.'" + ) + else: if self.soft_fail: raise AirflowSkipException( @@ -217,7 +242,7 @@ def poke(self, context, session=None): return count_allowed == len(dttm_filter) def _check_for_existence(self, session) -> None: - dag_to_wait = session.query(DagModel).filter(DagModel.dag_id == self.external_dag_id).first() + dag_to_wait = DagModel.get_current(self.external_dag_id, session) if not dag_to_wait: raise AirflowException(f'The external DAG {self.external_dag_id} does not exist.') @@ -233,6 +258,16 @@ def _check_for_existence(self, session) -> None: f'The external task {external_task_id} in ' f'DAG {self.external_dag_id} does not exist.' ) + + if self.external_task_group_id: + refreshed_dag_info = DagBag(dag_to_wait.fileloc).get_dag(self.external_dag_id) + + if not refreshed_dag_info.has_task_group(self.external_task_group_id): + raise AirflowException( + f'The external task group {self.external_task_group_id} in ' + f'DAG {self.external_dag_id} does not exist.' + ) + self._has_checked_existence = True def get_count(self, dttm_filter, session, states) -> int: @@ -252,24 +287,24 @@ def get_count(self, dttm_filter, session, states) -> int: if self.external_task_ids: count = ( session.query(func.count()) # .count() is inefficient - .filter( + .filter( TI.dag_id == self.external_dag_id, TI.task_id.in_(self.external_task_ids), TI.state.in_(states), TI.execution_date.in_(dttm_filter), ) - .scalar() + .scalar() ) count = count / len(self.external_task_ids) else: count = ( session.query(func.count()) - .filter( + .filter( DR.dag_id == self.external_dag_id, DR.state.in_(states), DR.execution_date.in_(dttm_filter), ) - .scalar() + .scalar() ) return count From edb0204ef59b0de4cace6f9bd518917511f6dde7 Mon Sep 17 00:00:00 2001 From: Ashish Patel Date: Tue, 28 Jun 2022 17:28:59 +0100 Subject: [PATCH 02/13] airflow-14563- task_group sensor working --- airflow/models/dag.py | 9 ++-- airflow/sensors/external_task.py | 90 +++++++++++++++++++++++--------- 2 files changed, 69 insertions(+), 30 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 981cff265962e..29dcbe2b39c57 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2221,10 +2221,11 @@ def has_task(self, task_id: str): return task_id in self.task_dict def has_task_group(self, task_group_id: str): - task_groups_ids = set( - task_id.split(".")[0] for task_id in self.task_ids if - "." in task_id) - return task_group_id in task_groups_ids + return task_group_id in self.task_group_dict + + @property + def task_group_dict(self) -> Dict[str, "TaskGroup"]: + return {k: v for k, v in self._task_group.get_task_group_dict().items() if k is not None} def get_task(self, task_id: str, include_subdags: bool = False) -> Operator: if task_id in self.task_dict: diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py index 12eea024a4ef9..2f31ded9ae64c 100644 --- a/airflow/sensors/external_task.py +++ b/airflow/sensors/external_task.py @@ -36,6 +36,9 @@ from airflow.utils.session import provide_session from airflow.utils.state import State +if TYPE_CHECKING: + from sqlalchemy.orm import Query + class ExternalDagLink(BaseOperatorLink): """ @@ -193,12 +196,21 @@ def poke(self, context, session=None): dttm_filter = self._get_dttm_filter(context) serialized_dttm_filter = ','.join(dt.isoformat() for dt in dttm_filter) - self.log.info( - 'Poking for tasks %s in dag %s on %s ... ', - self.external_task_ids, - self.external_dag_id, - serialized_dttm_filter, - ) + if self.external_task_ids: + self.log.info( + 'Poking for tasks %s in dag %s on %s ... ', + self.external_task_ids, + self.external_dag_id, + serialized_dttm_filter, + ) + + if self.external_task_group_id: + self.log.info( + 'Poking for task_group %s in dag %s on %s ... ', + self.external_task_group_id, + self.external_dag_id, + serialized_dttm_filter, + ) # In poke mode this will check dag existence only once if self.check_existence and not self._has_checked_existence: @@ -261,12 +273,15 @@ def _check_for_existence(self, session) -> None: if self.external_task_group_id: refreshed_dag_info = DagBag(dag_to_wait.fileloc).get_dag(self.external_dag_id) - + self.log.info("ashish") + self.log.info(str(refreshed_dag_info.task_group_dict)) if not refreshed_dag_info.has_task_group(self.external_task_group_id): raise AirflowException( f'The external task group {self.external_task_group_id} in ' f'DAG {self.external_dag_id} does not exist.' ) + else: # remove this + self.log.info("%s exists in %s", self.external_task_group_id, self.external_dag_id) self._has_checked_existence = True @@ -286,28 +301,51 @@ def get_count(self, dttm_filter, session, states) -> int: if self.external_task_ids: count = ( - session.query(func.count()) # .count() is inefficient - .filter( - TI.dag_id == self.external_dag_id, - TI.task_id.in_(self.external_task_ids), - TI.state.in_(states), - TI.execution_date.in_(dttm_filter), - ) - .scalar() - ) - count = count / len(self.external_task_ids) - else: + self._count_query(TI, session, states, dttm_filter) + .filter(TI.task_id.in_(self.external_task_ids)) + .scalar() + ) / len(self.external_task_ids) + elif self.external_task_group_id: + external_task_group_task_ids = self.get_external_task_group_task_ids(session) + self.log.info(str(external_task_group_task_ids)) + # we need list of task_ids for this task_group count = ( - session.query(func.count()) - .filter( - DR.dag_id == self.external_dag_id, - DR.state.in_(states), - DR.execution_date.in_(dttm_filter), - ) - .scalar() - ) + self._count_query(TI, session, states, dttm_filter) + .filter(TI.task_id.in_(external_task_group_task_ids)) + .scalar() + ) / len(external_task_group_task_ids) + else: + count = self._count_query(DR, session, states, dttm_filter).scalar() return count + def _count_query(self, model, session, states, dttm_filter) -> "Query": + query = session.query(func.count()).filter( + model.dag_id == self.external_dag_id, + model.state.in_(states), # pylint: disable=no-member + model.execution_date.in_(dttm_filter), + ) + return query + + # def get_external_task_group_task_ids(self, session): + # """Return task ids for the external TaskGroup""" + # refreshed_dag_info = DagBag(read_dags_from_db=True).get_dag(self.external_dag_id, session) + # task_group: Optional["TaskGroup"] = refreshed_dag_info.task_group_dict.get( + # self.external_task_group_id + # ) + # if not task_group: + # raise AirflowException( + # f"The external task group {self.external_task_group_id} in " + # f"DAG {self.external_dag_id} does not exist." + # ) + # task_ids = [task.task_id for task in task_group] + # return task_ids + + def get_external_task_group_task_ids(self, session): + refreshed_dag_info = DagBag(read_dags_from_db=True).get_dag(self.external_dag_id, session) + task_group = refreshed_dag_info.task_group_dict.get(self.external_task_group_id) + task_ids = [task.task_id for task in task_group] + return task_ids + def _handle_execution_date_fn(self, context) -> Any: """ This function is to handle backwards compatibility with how this operator was From d7a36fe66557fb23abfd78391bd8805395aa55cd Mon Sep 17 00:00:00 2001 From: Ashish Patel Date: Tue, 28 Jun 2022 18:20:42 +0100 Subject: [PATCH 03/13] airflow-14563- task_group sensor working --- airflow/models/dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 29dcbe2b39c57..1b0d31010a184 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2224,7 +2224,7 @@ def has_task_group(self, task_group_id: str): return task_group_id in self.task_group_dict @property - def task_group_dict(self) -> Dict[str, "TaskGroup"]: + def task_group_dict(self): return {k: v for k, v in self._task_group.get_task_group_dict().items() if k is not None} def get_task(self, task_id: str, include_subdags: bool = False) -> Operator: From 98d61d79fd70d324e40faba909b57aae31f941bb Mon Sep 17 00:00:00 2001 From: Ashish Patel Date: Tue, 28 Jun 2022 18:22:21 +0100 Subject: [PATCH 04/13] airflow-14563- task_group sensor working --- airflow/sensors/external_task.py | 43 +++++++++++--------------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py index 2f31ded9ae64c..a1193ac54e0dd 100644 --- a/airflow/sensors/external_task.py +++ b/airflow/sensors/external_task.py @@ -206,7 +206,7 @@ def poke(self, context, session=None): if self.external_task_group_id: self.log.info( - 'Poking for task_group %s in dag %s on %s ... ', + "Poking for task_group '%s' in dag '%s' on %s ... ", self.external_task_group_id, self.external_dag_id, serialized_dttm_filter, @@ -236,12 +236,12 @@ def poke(self, context, session=None): elif self.external_task_group_id: if self.soft_fail: raise AirflowSkipException( - f"The external task group {self.external_task_group_id}" - f"in DAG {self.external_dag_id} failed. Skipping due to soft_fail." + f"The external task_group '{self.external_task_group_id}'" + f"in DAG '{self.external_dag_id}' failed. Skipping due to soft_fail." ) raise AirflowException( - f"The external task group {self.external_task_group_id}" - f"in DAG {self.external_dag_id} failed.'" + f"The external task_group '{self.external_task_group_id}'" + f"in DAG '{self.external_dag_id}' failed.'" ) else: @@ -273,15 +273,11 @@ def _check_for_existence(self, session) -> None: if self.external_task_group_id: refreshed_dag_info = DagBag(dag_to_wait.fileloc).get_dag(self.external_dag_id) - self.log.info("ashish") - self.log.info(str(refreshed_dag_info.task_group_dict)) if not refreshed_dag_info.has_task_group(self.external_task_group_id): raise AirflowException( - f'The external task group {self.external_task_group_id} in ' - f'DAG {self.external_dag_id} does not exist.' + f"The external task group '{self.external_task_group_id}' in " + f"DAG '{self.external_dag_id}' does not exist." ) - else: # remove this - self.log.info("%s exists in %s", self.external_task_group_id, self.external_dag_id) self._has_checked_existence = True @@ -307,8 +303,6 @@ def get_count(self, dttm_filter, session, states) -> int: ) / len(self.external_task_ids) elif self.external_task_group_id: external_task_group_task_ids = self.get_external_task_group_task_ids(session) - self.log.info(str(external_task_group_task_ids)) - # we need list of task_ids for this task_group count = ( self._count_query(TI, session, states, dttm_filter) .filter(TI.task_id.in_(external_task_group_task_ids)) @@ -326,25 +320,16 @@ def _count_query(self, model, session, states, dttm_filter) -> "Query": ) return query - # def get_external_task_group_task_ids(self, session): - # """Return task ids for the external TaskGroup""" - # refreshed_dag_info = DagBag(read_dags_from_db=True).get_dag(self.external_dag_id, session) - # task_group: Optional["TaskGroup"] = refreshed_dag_info.task_group_dict.get( - # self.external_task_group_id - # ) - # if not task_group: - # raise AirflowException( - # f"The external task group {self.external_task_group_id} in " - # f"DAG {self.external_dag_id} does not exist." - # ) - # task_ids = [task.task_id for task in task_group] - # return task_ids - def get_external_task_group_task_ids(self, session): refreshed_dag_info = DagBag(read_dags_from_db=True).get_dag(self.external_dag_id, session) task_group = refreshed_dag_info.task_group_dict.get(self.external_task_group_id) - task_ids = [task.task_id for task in task_group] - return task_ids + + if task_group: + return [task.task_id for task in task_group] + + # returning default task_id as group_id itself, this will avoid any failure in case of + # 'check_existence=False' and will fail on timeout + return [self.external_task_group_id] def _handle_execution_date_fn(self, context) -> Any: """ From 8a0b730ea04c5efd3b048f9e18d4696cb59f46eb Mon Sep 17 00:00:00 2001 From: Ashish Patel Date: Thu, 7 Jul 2022 10:56:10 +0100 Subject: [PATCH 05/13] airflow-14563- task_group sensor working --- tests/sensors/test_external_task_sensor.py | 105 +++++++++++++++++---- 1 file changed, 87 insertions(+), 18 deletions(-) diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index c1505dc13d646..853d0e8092efa 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -25,13 +25,16 @@ from airflow.exceptions import AirflowException, AirflowSensorTimeout from airflow.models import DagBag, DagRun, TaskInstance from airflow.models.dag import DAG +from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.bash import BashOperator +from airflow.operators.dummy import DummyOperator from airflow.operators.empty import EmptyOperator from airflow.sensors.external_task import ExternalTaskMarker, ExternalTaskSensor, ExternalTaskSensorLink from airflow.sensors.time_sensor import TimeSensor from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.utils.session import provide_session from airflow.utils.state import DagRunState, State, TaskInstanceState +from airflow.utils.task_group import TaskGroup from airflow.utils.timezone import datetime from airflow.utils.types import DagRunType from tests.test_utils.db import clear_db_runs @@ -40,6 +43,7 @@ TEST_DAG_ID = 'unit_test_dag' TEST_TASK_ID = 'time_sensor_check' TEST_TASK_ID_ALTERNATE = 'time_sensor_check_alternate' +TEST_TASK_GROUP_ID = 'time_sensor_group_id' DEV_NULL = '/dev/null' @@ -54,12 +58,24 @@ def setUp(self): self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG(TEST_DAG_ID, default_args=self.args) - def test_time_sensor(self, task_id=TEST_TASK_ID): + def add_time_sensor(self, task_id=TEST_TASK_ID): op = TimeSensor(task_id=task_id, target_time=time(0), dag=self.dag) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + def add_dummy_task_group(self, target_states=None): + target_states = [State.SUCCESS] * 2 if target_states is None else target_states + with self.dag as dag: + with TaskGroup(group_id=TEST_TASK_GROUP_ID) as task_group: + _ = [DummyOperator(task_id=f"task{i}") for i in range(len(target_states))] + SerializedDagModel.write_dag(dag) + + for idx, task in enumerate(task_group): + ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) + ti.run(ignore_ti_state=True, mark_success=True) + ti.set_state(target_states[idx]) + def test_external_task_sensor(self): - self.test_time_sensor() + self.add_time_sensor() op = ExternalTaskSensor( task_id='test_external_task_sensor_check', external_dag_id=TEST_DAG_ID, @@ -69,8 +85,8 @@ def test_external_task_sensor(self): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_external_task_sensor_multiple_task_ids(self): - self.test_time_sensor(task_id=TEST_TASK_ID) - self.test_time_sensor(task_id=TEST_TASK_ID_ALTERNATE) + self.add_time_sensor(task_id=TEST_TASK_ID) + self.add_time_sensor(task_id=TEST_TASK_ID_ALTERNATE) op = ExternalTaskSensor( task_id='test_external_task_sensor_check_task_ids', external_dag_id=TEST_DAG_ID, @@ -79,6 +95,59 @@ def test_external_task_sensor_multiple_task_ids(self): ) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + def test_external_task_sensor_with_task_group(self): + self.add_time_sensor() + self.add_dummy_task_group() + op = ExternalTaskSensor( + task_id='test_external_task_sensor_task_group', + external_dag_id=TEST_DAG_ID, + external_task_group_id=TEST_TASK_GROUP_ID, + dag=self.dag, + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + def test_raise_with_external_task_sensor_task_group_and_task_id(self): + with pytest.raises(ValueError) as ctx: + ExternalTaskSensor( + task_id='test_external_task_sensor_task_group_with_task_id_failed_status', + external_dag_id=TEST_DAG_ID, + external_task_ids=TEST_TASK_ID, + external_task_group_id=TEST_TASK_GROUP_ID, + dag=self.dag, + ) + assert ( + str(ctx.value) == "Values for `external_task_group_id` and `external_task_id` or " + "`external_task_ids` can't be set at the same time" + ) + + # by default i.e. check_existence=False, if task_group doesn't exist, the sensor will run till timeout, + # this behaviour is similar to external_task_id doesn't exists + def test_external_task_group_not_exists_without_check_existence(self): + self.add_time_sensor() + self.add_dummy_task_group() + with pytest.raises(AirflowException, match=f"Snap. Time is OUT. DAG id: {TEST_DAG_ID}"): + op = ExternalTaskSensor( + task_id='test_external_task_sensor_check', + external_dag_id=TEST_DAG_ID, + external_task_group_id='fake-task-group', + timeout=1, + dag=self.dag, + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + def test_external_task_group_sensor_success(self): + self.add_time_sensor() + self.add_dummy_task_group() + op = ExternalTaskSensor( + task_id='test_external_task_sensor_check', + external_dag_id=TEST_DAG_ID, + external_task_group_id=TEST_TASK_GROUP_ID, + failed_states=[State.FAILED], + dag=self.dag, + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + def test_catch_overlap_allowed_failed_state(self): with pytest.raises(AirflowException): ExternalTaskSensor( @@ -101,7 +170,7 @@ def test_external_task_sensor_wrong_failed_states(self): ) def test_external_task_sensor_failed_states(self): - self.test_time_sensor() + self.add_time_sensor() op = ExternalTaskSensor( task_id='test_external_task_sensor_check', external_dag_id=TEST_DAG_ID, @@ -112,7 +181,7 @@ def test_external_task_sensor_failed_states(self): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_external_task_sensor_failed_states_as_success(self): - self.test_time_sensor() + self.add_time_sensor() op = ExternalTaskSensor( task_id='test_external_task_sensor_check', external_dag_id=TEST_DAG_ID, @@ -135,7 +204,7 @@ def test_external_task_sensor_failed_states_as_success(self): ) def test_external_task_sensor_soft_fail_failed_states_as_skipped(self, session=None): - self.test_time_sensor() + self.add_time_sensor() op = ExternalTaskSensor( task_id='test_external_task_sensor_check', external_dag_id=TEST_DAG_ID, @@ -158,7 +227,7 @@ def test_external_task_sensor_soft_fail_failed_states_as_skipped(self, session=N def test_external_task_sensor_external_task_id_param(self): """Test external_task_ids is set properly when external_task_id is passed as a template""" - self.test_time_sensor() + self.add_time_sensor() op = ExternalTaskSensor( task_id='test_external_task_sensor_check', external_dag_id='{{ params.dag_id }}', @@ -176,7 +245,7 @@ def test_external_task_sensor_external_task_id_param(self): def test_external_task_sensor_external_task_ids_param(self): """Test external_task_ids rendering when a template is passed.""" - self.test_time_sensor() + self.add_time_sensor() op = ExternalTaskSensor( task_id='test_external_task_sensor_check', external_dag_id='{{ params.dag_id }}', @@ -193,8 +262,8 @@ def test_external_task_sensor_external_task_ids_param(self): ) def test_external_task_sensor_failed_states_as_success_mulitple_task_ids(self): - self.test_time_sensor(task_id=TEST_TASK_ID) - self.test_time_sensor(task_id=TEST_TASK_ID_ALTERNATE) + self.add_time_sensor(task_id=TEST_TASK_ID) + self.add_time_sensor(task_id=TEST_TASK_ID_ALTERNATE) op = ExternalTaskSensor( task_id='test_external_task_sensor_check_task_ids', external_dag_id=TEST_DAG_ID, @@ -333,7 +402,7 @@ def test_external_task_sensor_fn_multiple_execution_dates(self): task_with_failure.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_external_task_sensor_delta(self): - self.test_time_sensor() + self.add_time_sensor() op = ExternalTaskSensor( task_id='test_external_task_sensor_check_delta', external_dag_id=TEST_DAG_ID, @@ -345,7 +414,7 @@ def test_external_task_sensor_delta(self): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_external_task_sensor_fn(self): - self.test_time_sensor() + self.add_time_sensor() # check that the execution_fn works op1 = ExternalTaskSensor( task_id='test_external_task_sensor_check_delta_1', @@ -372,7 +441,7 @@ def test_external_task_sensor_fn(self): def test_external_task_sensor_fn_multiple_args(self): """Check this task sensor passes multiple args with full context. If no failure, means clean run.""" - self.test_time_sensor() + self.add_time_sensor() def my_func(dt, context): assert context['logical_date'] == dt @@ -390,7 +459,7 @@ def my_func(dt, context): def test_external_task_sensor_fn_kwargs(self): """Check this task sensor passes multiple args with full context. If no failure, means clean run.""" - self.test_time_sensor() + self.add_time_sensor() def my_func(dt, ds_nodash, tomorrow_ds_nodash): assert ds_nodash == dt.strftime("%Y%m%d") @@ -408,7 +477,7 @@ def my_func(dt, ds_nodash, tomorrow_ds_nodash): op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_external_task_sensor_error_delta_and_fn(self): - self.test_time_sensor() + self.add_time_sensor() # Test that providing execution_delta and a function raises an error with pytest.raises(ValueError): ExternalTaskSensor( @@ -422,7 +491,7 @@ def test_external_task_sensor_error_delta_and_fn(self): ) def test_external_task_sensor_error_task_id_and_task_ids(self): - self.test_time_sensor() + self.add_time_sensor() # Test that providing execution_delta and a function raises an error with pytest.raises(ValueError): ExternalTaskSensor( @@ -435,7 +504,7 @@ def test_external_task_sensor_error_task_id_and_task_ids(self): ) def test_catch_duplicate_task_ids(self): - self.test_time_sensor() + self.add_time_sensor() # Test By passing same task_id multiple times with pytest.raises(ValueError): ExternalTaskSensor( From fb76d956c4a8151eb4067c356c59e072b4def1a5 Mon Sep 17 00:00:00 2001 From: Ashish Patel Date: Thu, 7 Jul 2022 16:11:50 +0100 Subject: [PATCH 06/13] airflow-14563- including task_group dependency to check task states --- airflow/sensors/external_task.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py index a1193ac54e0dd..45e00b52c2fc3 100644 --- a/airflow/sensors/external_task.py +++ b/airflow/sensors/external_task.py @@ -153,18 +153,19 @@ def __init__( "can't be set at the same time" ) - if external_task_ids: + if external_task_ids or external_task_group_id: if not total_states <= set(State.task_states): raise ValueError( f'Valid values for `allowed_states` and `failed_states` ' - f'when `external_task_id` or `external_task_ids` is not `None`: {State.task_states}' + f'when `external_task_id` or `external_task_ids` or `external_task_group_id` ' + f'is not `None`: {State.task_states}' ) - if len(external_task_ids) > len(set(external_task_ids)): + if external_task_ids and len(external_task_ids) > len(set(external_task_ids)): raise ValueError('Duplicate task_ids passed in external_task_ids parameter') elif not total_states <= set(State.dag_states): raise ValueError( f'Valid values for `allowed_states` and `failed_states` ' - f'when `external_task_id` is `None`: {State.dag_states}' + f'when `external_task_id` and `external_task_group_id` is `None`: {State.dag_states}' ) if execution_delta is not None and execution_date_fn is not None: From 6360546477e9934287a820ff2ba567ed96760638 Mon Sep 17 00:00:00 2001 From: Ashish Patel Date: Thu, 7 Jul 2022 16:16:31 +0100 Subject: [PATCH 07/13] airflow-14563- task_group sensor working --- airflow/sensors/external_task.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py index 45e00b52c2fc3..19005ea19ec1a 100644 --- a/airflow/sensors/external_task.py +++ b/airflow/sensors/external_task.py @@ -237,12 +237,12 @@ def poke(self, context, session=None): elif self.external_task_group_id: if self.soft_fail: raise AirflowSkipException( - f"The external task_group '{self.external_task_group_id}'" + f"The external task_group '{self.external_task_group_id}' " f"in DAG '{self.external_dag_id}' failed. Skipping due to soft_fail." ) raise AirflowException( - f"The external task_group '{self.external_task_group_id}'" - f"in DAG '{self.external_dag_id}' failed.'" + f"The external task_group '{self.external_task_group_id}' " + f"in DAG '{self.external_dag_id}' failed." ) else: From 9bd77083bea7264751b6c646a6bb8e81feab5078 Mon Sep 17 00:00:00 2001 From: Ashish Patel Date: Thu, 7 Jul 2022 16:22:55 +0100 Subject: [PATCH 08/13] airflow-14563- tests added --- tests/sensors/test_external_task_sensor.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index 853d0e8092efa..c3406ebc8dc62 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -147,6 +147,20 @@ def test_external_task_group_sensor_success(self): ) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + def test_external_task_group_sensor_failed_states(self): + ti_states = [State.FAILED, State.FAILED] + self.add_time_sensor() + self.add_dummy_task_group(ti_states) + op = ExternalTaskSensor( + task_id='test_external_task_sensor_check', + external_dag_id=TEST_DAG_ID, + external_task_group_id=TEST_TASK_GROUP_ID, + failed_states=[State.FAILED], + dag=self.dag, + ) + with pytest.raises(AirflowException, match=f"The external task_group '{TEST_TASK_GROUP_ID}' in DAG " + f"'{TEST_DAG_ID}' failed."): + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_catch_overlap_allowed_failed_state(self): with pytest.raises(AirflowException): From e7b148eec574a1c9ab03a7c278a4a816a7152c35 Mon Sep 17 00:00:00 2001 From: Ashish Patel Date: Thu, 7 Jul 2022 16:53:59 +0100 Subject: [PATCH 09/13] airflow-14563- docs added --- .../example_external_task_marker_dag.py | 18 +++++++++++++++--- .../howto/operator/external_task_sensor.rst | 10 ++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/airflow/example_dags/example_external_task_marker_dag.py b/airflow/example_dags/example_external_task_marker_dag.py index 88dae819de2b5..0c9bd25fd26c6 100644 --- a/airflow/example_dags/example_external_task_marker_dag.py +++ b/airflow/example_dags/example_external_task_marker_dag.py @@ -69,7 +69,6 @@ catchup=False, tags=['example2'], ) as child_dag: - # [START howto_operator_external_task_sensor] child_task1 = ExternalTaskSensor( task_id="child_task1", external_dag_id=parent_dag.dag_id, @@ -80,5 +79,18 @@ mode="reschedule", ) # [END howto_operator_external_task_sensor] - child_task2 = EmptyOperator(task_id="child_task2") - child_task1 >> child_task2 + + # [START howto_operator_external_task_sensor_with_task_group] + child_task2 = ExternalTaskSensor( + task_id="child_task1", + external_dag_id=parent_dag.dag_id, + external_task_group_id='parent_dag_task_group_id', + timeout=600, + allowed_states=['success'], + failed_states=['failed', 'skipped'], + mode="reschedule", + ) + # [END howto_operator_external_task_sensor_with_task_group] + + child_task3 = EmptyOperator(task_id="child_task3") + child_task1 >> child_task2 >> child_task3 diff --git a/docs/apache-airflow/howto/operator/external_task_sensor.rst b/docs/apache-airflow/howto/operator/external_task_sensor.rst index f6ae421969bec..923f8ec3d1161 100644 --- a/docs/apache-airflow/howto/operator/external_task_sensor.rst +++ b/docs/apache-airflow/howto/operator/external_task_sensor.rst @@ -53,6 +53,16 @@ via ``allowed_states`` and ``failed_states`` parameters. :start-after: [START howto_operator_external_task_sensor] :end-before: [END howto_operator_external_task_sensor] +ExternalTaskSensor with task_group dependency +--------------------------------------------- +In Addition, we can also use the :class:`~airflow.sensors.external_task.ExternalTaskSensor` to make tasks on a DAG +wait for another ``task_group`` on a different DAG for a specific ``execution_date``. + +.. exampleinclude:: /../../airflow/example_dags/example_external_task_marker_dag.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_external_task_sensor_with_task_group] + :end-before: [END howto_operator_external_task_sensor_with_task_group] ExternalTaskMarker From b93613122e6738e6ef17493381a35184fb085009 Mon Sep 17 00:00:00 2001 From: Ashish Patel Date: Thu, 7 Jul 2022 17:07:36 +0100 Subject: [PATCH 10/13] airflow-14563- docs added --- airflow/example_dags/example_external_task_marker_dag.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/airflow/example_dags/example_external_task_marker_dag.py b/airflow/example_dags/example_external_task_marker_dag.py index 0c9bd25fd26c6..874658d9bbaf5 100644 --- a/airflow/example_dags/example_external_task_marker_dag.py +++ b/airflow/example_dags/example_external_task_marker_dag.py @@ -69,6 +69,7 @@ catchup=False, tags=['example2'], ) as child_dag: + # [START howto_operator_external_task_sensor] child_task1 = ExternalTaskSensor( task_id="child_task1", external_dag_id=parent_dag.dag_id, @@ -82,7 +83,7 @@ # [START howto_operator_external_task_sensor_with_task_group] child_task2 = ExternalTaskSensor( - task_id="child_task1", + task_id="child_task2", external_dag_id=parent_dag.dag_id, external_task_group_id='parent_dag_task_group_id', timeout=600, From ab02251ee05955a8ebbfd87c64f42a58bba346f0 Mon Sep 17 00:00:00 2001 From: Ashish Patel Date: Thu, 7 Jul 2022 17:18:01 +0100 Subject: [PATCH 11/13] airflow-14563- docs added --- tests/sensors/test_external_task_sensor.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index c3406ebc8dc62..09c807ed7bfbe 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -117,7 +117,7 @@ def test_raise_with_external_task_sensor_task_group_and_task_id(self): ) assert ( str(ctx.value) == "Values for `external_task_group_id` and `external_task_id` or " - "`external_task_ids` can't be set at the same time" + "`external_task_ids` can't be set at the same time" ) # by default i.e. check_existence=False, if task_group doesn't exist, the sensor will run till timeout, @@ -158,8 +158,10 @@ def test_external_task_group_sensor_failed_states(self): failed_states=[State.FAILED], dag=self.dag, ) - with pytest.raises(AirflowException, match=f"The external task_group '{TEST_TASK_GROUP_ID}' in DAG " - f"'{TEST_DAG_ID}' failed."): + with pytest.raises( + AirflowException, + match=f"The external task_group '{TEST_TASK_GROUP_ID}' in DAG '{TEST_DAG_ID}' failed.", + ): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_catch_overlap_allowed_failed_state(self): From 139cce4a98b35ca2e38b6c06da868b00a30f02b2 Mon Sep 17 00:00:00 2001 From: Ashish Patel Date: Mon, 18 Jul 2022 15:26:18 +0100 Subject: [PATCH 12/13] airflow-14563- pushing review comments. --- airflow/models/dag.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 1b0d31010a184..944dc6dbc8064 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2220,10 +2220,10 @@ def filter_task_group(group, parent_group): def has_task(self, task_id: str): return task_id in self.task_dict - def has_task_group(self, task_group_id: str): + def has_task_group(self, task_group_id: str) -> bool: return task_group_id in self.task_group_dict - @property + @cached_property def task_group_dict(self): return {k: v for k, v in self._task_group.get_task_group_dict().items() if k is not None} From 26cb8cf27dfca92821c3a51699e4826114822db0 Mon Sep 17 00:00:00 2001 From: Ashish Patel Date: Sun, 14 Aug 2022 21:34:51 +0100 Subject: [PATCH 13/13] airflow-14563- pushing review comments. --- tests/sensors/test_external_task_sensor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index 09c807ed7bfbe..501224bea40e1 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -27,7 +27,6 @@ from airflow.models.dag import DAG from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.bash import BashOperator -from airflow.operators.dummy import DummyOperator from airflow.operators.empty import EmptyOperator from airflow.sensors.external_task import ExternalTaskMarker, ExternalTaskSensor, ExternalTaskSensorLink from airflow.sensors.time_sensor import TimeSensor @@ -66,7 +65,7 @@ def add_dummy_task_group(self, target_states=None): target_states = [State.SUCCESS] * 2 if target_states is None else target_states with self.dag as dag: with TaskGroup(group_id=TEST_TASK_GROUP_ID) as task_group: - _ = [DummyOperator(task_id=f"task{i}") for i in range(len(target_states))] + _ = [EmptyOperator(task_id=f"task{i}") for i in range(len(target_states))] SerializedDagModel.write_dag(dag) for idx, task in enumerate(task_group):