diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py index 21d88da4874e6..d4c81ca1f84c7 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -758,33 +758,55 @@ def post_clear_task_instances( if future: body.end_date = None - task_ids = body.task_ids - if task_ids is not None: - tasks = set(task_ids) - mapped_tasks_tuples = set(t for t in tasks if isinstance(t, tuple)) + if (task_markers_to_clear := body.task_ids) is not None: + mapped_tasks_tuples = {t for t in task_markers_to_clear if isinstance(t, tuple)} # Unmapped tasks are expressed in their task_ids (without map_indexes) - unmapped_task_ids = set(t for t in tasks if not isinstance(t, tuple)) - - if upstream or downstream: - mapped_task_ids = set(tid for tid, _ in mapped_tasks_tuples) - relatives = dag.partial_subset( - task_ids=unmapped_task_ids | mapped_task_ids, - include_downstream=downstream, - include_upstream=upstream, - exclude_original=True, - ) - unmapped_task_ids = unmapped_task_ids | set(relatives.task_dict.keys()) + normal_task_ids = {t for t in task_markers_to_clear if not isinstance(t, tuple)} + + def _collect_relatives(run_id: str, direction: Literal["upstream", "downstream"]) -> None: + from airflow.models.taskinstance import find_relevant_relatives - mapped_tasks_list = [ - (tid, map_id) for tid, map_id in mapped_tasks_tuples if tid not in unmapped_task_ids + relevant_relatives = find_relevant_relatives( + normal_task_ids, + mapped_tasks_tuples, + dag=dag, + run_id=run_id, + direction=direction, + session=session, + ) + normal_task_ids.update(t for t in relevant_relatives if not isinstance(t, tuple)) + mapped_tasks_tuples.update(t for t in relevant_relatives if isinstance(t, tuple)) + + # We can't easily calculate upstream/downstream map indexes when not + # working for a specific dag run. It's possible by looking at the runs + # one by one, but that is both resource-consuming and logically complex. + # So instead we'll just clear all the tis based on task ID and hope + # that's good enough for most cases. + if dag_run_id is None: + if upstream or downstream: + partial_dag = dag.partial_subset( + task_ids=normal_task_ids.union(tid for tid, _ in mapped_tasks_tuples), + include_downstream=downstream, + include_upstream=upstream, + exclude_original=True, + ) + normal_task_ids.update(partial_dag.task_dict) + else: + if upstream: + _collect_relatives(dag_run_id, "upstream") + if downstream: + _collect_relatives(dag_run_id, "downstream") + + task_markers_to_clear = [ + *normal_task_ids, + *((t, m) for t, m in mapped_tasks_tuples if t not in normal_task_ids), ] - task_ids = mapped_tasks_list + list(unmapped_task_ids) if dag_run_id is not None and not (past or future): # Use run_id-based clearing when we have a specific dag_run_id and not using past/future task_instances = dag.clear( dry_run=True, - task_ids=task_ids, + task_ids=task_markers_to_clear, run_id=dag_run_id, session=session, run_on_latest_version=body.run_on_latest_version, @@ -795,7 +817,7 @@ def post_clear_task_instances( # Use date-based clearing when no dag_run_id or when past/future is specified task_instances = dag.clear( dry_run=True, - task_ids=task_ids, + task_ids=task_markers_to_clear, start_date=body.start_date, end_date=body.end_date, session=session, diff --git a/airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py b/airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py index 750c3da1ec17b..c7b3a02301daa 100644 --- a/airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py +++ b/airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py @@ -22,9 +22,9 @@ # [START example_dynamic_task_mapping] from datetime import datetime -from airflow.sdk import DAG, task +from airflow.sdk import DAG, task, task_group -with DAG(dag_id="example_dynamic_task_mapping", schedule=None, start_date=datetime(2022, 3, 4)) as dag: +with DAG(dag_id="example_dynamic_task_mapping", schedule=None, start_date=datetime(2022, 3, 4)): @task def add_one(x: int): @@ -39,8 +39,11 @@ def sum_it(values): sum_it(added_values) with DAG( - dag_id="example_task_mapping_second_order", schedule=None, catchup=False, start_date=datetime(2022, 3, 4) -) as dag2: + dag_id="example_task_mapping_second_order", + schedule=None, + catchup=False, + start_date=datetime(2022, 3, 4), +): @task def get_nums(): @@ -58,4 +61,25 @@ def add_10(num): _times_2 = times_2.expand(num=_get_nums) add_10.expand(num=_times_2) +with DAG( + dag_id="example_task_group_mapping", + schedule=None, + catchup=False, + start_date=datetime(2022, 3, 4), +): + + @task_group + def op(num): + @task + def add_1(num): + return num + 1 + + @task + def mul_2(num): + return num * 2 + + return mul_2(add_1(num)) + + op.expand(num=[1, 2, 3]) + # [END example_dynamic_task_mapping] diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 81c44eb2b0b68..8e1e188395405 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -28,7 +28,7 @@ from collections.abc import Collection, Iterable from datetime import datetime, timedelta from functools import cache -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from urllib.parse import quote import attrs @@ -119,7 +119,7 @@ from airflow.sdk.definitions.asset import AssetUniqueKey from airflow.sdk.types import RuntimeTaskInstanceProtocol from airflow.serialization.definitions.taskgroup import SerializedTaskGroup - from airflow.serialization.serialized_objects import SerializedBaseOperator + from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG from airflow.utils.context import Context Operator: TypeAlias = MappedOperator | SerializedBaseOperator @@ -2040,87 +2040,16 @@ def get_relevant_upstream_map_indexes( *, session: Session, ) -> int | range | None: - """ - Infer the map indexes of an upstream "relevant" to this ti. - - The bulk of the logic mainly exists to solve the problem described by - the following example, where 'val' must resolve to different values, - depending on where the reference is being used:: - - @task - def this_task(v): # This is self.task. - return v * 2 - - - @task_group - def tg1(inp): - val = upstream(inp) # This is the upstream task. - this_task(val) # When inp is 1, val here should resolve to 2. - return val - - - # This val is the same object returned by tg1. - val = tg1.expand(inp=[1, 2, 3]) - - - @task_group - def tg2(inp): - another_task(inp, val) # val here should resolve to [2, 4, 6]. - - - tg2.expand(inp=["a", "b"]) - - The surrounding mapped task groups of ``upstream`` and ``self.task`` are - inspected to find a common "ancestor". If such an ancestor is found, - we need to return specific map indexes to pull a partial value from - upstream XCom. - - :param upstream: The referenced upstream task. - :param ti_count: The total count of task instance this task was expanded - by the scheduler, i.e. ``expanded_ti_count`` in the template context. - :return: Specific map index or map indexes to pull, or ``None`` if we - want to "whole" return value (i.e. no mapped task groups involved). - """ - from airflow.models.mappedoperator import get_mapped_ti_count - if TYPE_CHECKING: - assert self.task is not None - - # This value should never be None since we already know the current task - # is in a mapped task group, and should have been expanded, despite that, - # we need to check that it is not None to satisfy Mypy. - # But this value can be 0 when we expand an empty list, for that it is - # necessary to check that ti_count is not 0 to avoid dividing by 0. - if not ti_count: - return None - - # Find the innermost common mapped task group between the current task - # If the current task and the referenced task does not have a common - # mapped task group, the two are in different task mapping contexts - # (like another_task above), and we should use the "whole" value. - common_ancestor = _find_common_ancestor_mapped_group(self.task, upstream) - if common_ancestor is None: - return None - - # At this point we know the two tasks share a mapped task group, and we - # should use a "partial" value. Let's break down the mapped ti count - # between the ancestor and further expansion happened inside it. - - ancestor_ti_count = get_mapped_ti_count(common_ancestor, self.run_id, session=session) - ancestor_map_index = self.map_index * ancestor_ti_count // ti_count - - # If the task is NOT further expanded inside the common ancestor, we - # only want to reference one single ti. We must walk the actual DAG, - # and "ti_count == ancestor_ti_count" does not work, since the further - # expansion may be of length 1. - if not _is_further_mapped_inside(upstream, common_ancestor): - return ancestor_map_index - - # Otherwise we need a partial aggregation for values from selected task - # instances in the ancestor's expansion context. - further_count = ti_count // ancestor_ti_count - map_index_start = ancestor_map_index * further_count - return range(map_index_start, map_index_start + further_count) + assert self.task + return _get_relevant_map_indexes( + run_id=self.run_id, + map_index=self.map_index, + ti_count=ti_count, + task=self.task, + relative=upstream, + session=session, + ) def clear_db_references(self, session: Session): """ @@ -2245,6 +2174,159 @@ def _is_further_mapped_inside(operator: Operator, container: SerializedTaskGroup return False +def _get_relevant_map_indexes( + *, + task: Operator, + run_id: str, + map_index: int, + relative: Operator, + ti_count: int | None, + session: Session, +) -> int | range | None: + """ + Infer the map indexes of a relative that's "relevant" to this ti. + + The bulk of the logic mainly exists to solve the problem described by + the following example, where 'val' must resolve to different values, + depending on where the reference is being used:: + + @task + def this_task(v): # This is self.task. + return v * 2 + + + @task_group + def tg1(inp): + val = upstream(inp) # This is the upstream task. + this_task(val) # When inp is 1, val here should resolve to 2. + return val + + + # This val is the same object returned by tg1. + val = tg1.expand(inp=[1, 2, 3]) + + + @task_group + def tg2(inp): + another_task(inp, val) # val here should resolve to [2, 4, 6]. + + + tg2.expand(inp=["a", "b"]) + + The surrounding mapped task groups of ``upstream`` and ``task`` are + inspected to find a common "ancestor". If such an ancestor is found, + we need to return specific map indexes to pull a partial value from + upstream XCom. + + The same logic apply for finding downstream tasks. + + :param task: Current task being inspected. + :param run_id: Current run ID. + :param map_index: Map index of the current task instance. + :param relative: The relative task to find relevant map indexes for. + :param ti_count: The total count of task instance this task was expanded + by the scheduler, i.e. ``expanded_ti_count`` in the template context. + :return: Specific map index or map indexes to pull, or ``None`` if we + want to "whole" return value (i.e. no mapped task groups involved). + """ + from airflow.models.mappedoperator import get_mapped_ti_count + + # This value should never be None since we already know the current task + # is in a mapped task group, and should have been expanded, despite that, + # we need to check that it is not None to satisfy Mypy. + # But this value can be 0 when we expand an empty list, for that it is + # necessary to check that ti_count is not 0 to avoid dividing by 0. + if not ti_count: + return None + + # Find the innermost common mapped task group between the current task + # If the current task and the referenced task does not have a common + # mapped task group, the two are in different task mapping contexts + # (like another_task above), and we should use the "whole" value. + if (common_ancestor := _find_common_ancestor_mapped_group(task, relative)) is None: + return None + + # At this point we know the two tasks share a mapped task group, and we + # should use a "partial" value. Let's break down the mapped ti count + # between the ancestor and further expansion happened inside it. + + ancestor_ti_count = get_mapped_ti_count(common_ancestor, run_id, session=session) + ancestor_map_index = map_index * ancestor_ti_count // ti_count + + # If the task is NOT further expanded inside the common ancestor, we + # only want to reference one single ti. We must walk the actual DAG, + # and "ti_count == ancestor_ti_count" does not work, since the further + # expansion may be of length 1. + if not _is_further_mapped_inside(relative, common_ancestor): + return ancestor_map_index + + # Otherwise we need a partial aggregation for values from selected task + # instances in the ancestor's expansion context. + further_count = ti_count // ancestor_ti_count + map_index_start = ancestor_map_index * further_count + return range(map_index_start, map_index_start + further_count) + + +def find_relevant_relatives( + normal_tasks: Iterable[str], + mapped_tasks: Iterable[tuple[str, int]], + *, + direction: Literal["upstream", "downstream"], + dag: SerializedDAG, + run_id: str, + session: Session, +) -> Collection[str | tuple[str, int]]: + from airflow.models.mappedoperator import get_mapped_ti_count + + visited: set[str | tuple[str, int]] = set() + + def _visit_relevant_relatives_for_normal(task_ids: Iterable[str]) -> None: + partial_dag = dag.partial_subset( + task_ids=task_ids, + include_downstream=direction == "downstream", + include_upstream=direction == "upstream", + exclude_original=True, + ) + visited.update(partial_dag.task_dict) + + def _visit_relevant_relatives_for_mapped(mapped_tasks: Iterable[tuple[str, int]]) -> None: + for task_id, map_index in mapped_tasks: + task = dag.get_task(task_id) + ti_count = get_mapped_ti_count(task, run_id, session=session) + # TODO (GH-52141): This should return scheduler operator types, but + # currently get_flat_relatives is inherited from SDK DAGNode. + relatives = cast("Iterable[Operator]", task.get_flat_relatives(upstream=direction == "upstream")) + for relative in relatives: + if relative.task_id in visited: + continue + relative_map_indexes = _get_relevant_map_indexes( + task=task, + relative=relative, # type: ignore[arg-type] + run_id=run_id, + map_index=map_index, + ti_count=ti_count, + session=session, + ) + visiting_mapped: set[tuple[str, int]] = set() + visiting_normal: set[str] = set() + match relative_map_indexes: + case int(): + if (item := (relative.task_id, relative_map_indexes)) not in visited: + visiting_mapped.add(item) + case range(): + visiting_mapped.update((relative.task_id, i) for i in relative_map_indexes) + case None: + if (task_id := relative.task_id) not in visited: + visiting_normal.add(task_id) + _visit_relevant_relatives_for_normal(visiting_normal) + _visit_relevant_relatives_for_mapped(visiting_mapped) + visited.update(visiting_mapped, visiting_normal) + + _visit_relevant_relatives_for_normal(normal_tasks) + _visit_relevant_relatives_for_mapped(mapped_tasks) + return visited + + class TaskInstanceNote(Base): """For storage of arbitrary notes concerning the task instance.""" diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py index 36605a7f4abdc..bbe92f58ec19c 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py @@ -2705,6 +2705,27 @@ class TestPostClearTaskInstances(TestTaskInstanceEndpoint): 4, id="clear mapped tasks with and without map index", ), + pytest.param( + "example_task_group_mapping", + [ + { + "state": State.FAILED, + "map_indexes": (0, 1, 2), + }, + { + "state": State.FAILED, + "map_indexes": (0, 1, 2), + }, + ], + "example_task_group_mapping", + { + "task_ids": [["op.mul_2", 0]], + "dag_run_id": "TEST_DAG_RUN_ID", + "include_upstream": True, + }, + 2, + id="clear tasks in mapped task group", + ), ], ) def test_should_respond_200( diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index b04ce2b7ffe7f..fa713d43e4d79 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -52,6 +52,7 @@ TaskInstance, TaskInstance as TI, TaskInstanceNote, + find_relevant_relatives, ) from airflow.models.taskinstancehistory import TaskInstanceHistory from airflow.models.taskmap import TaskMap @@ -3190,3 +3191,56 @@ def test_delete_dagversion_restricted_when_taskinstance_exists(dag_maker, sessio session.delete(version) with pytest.raises(IntegrityError): session.commit() + + +@pytest.mark.parametrize( + ("normal_tasks", "mapped_tasks", "expected"), + [ + # 4 is just a regular task so it depends on all its upstreams. + pytest.param(["4"], [], {"1", "2", "3"}, id="nonmapped"), + # 3 is a mapped; it depends on all tis of the mapped upstream 2. + pytest.param(["3"], [], {"1", "2"}, id="mapped-whole"), + # Every ti of a mapped task depends on all tis of the mapped upstream. + pytest.param([], [("3", 1)], {"1", "2"}, id="mapped-one"), + # Same as the (non-group) unmapped case, d depends on all upstreams. + pytest.param(["d"], [], {"a", "b", "c"}, id="group-nonmapped"), + # This specifies c tis in ALL mapped task groups, so all b tis are needed. + pytest.param(["c"], [], {"a", "b"}, id="group-mapped-whole"), + # This only specifies one c ti, so only one b ti from the same mapped instance is returned. + pytest.param([], [("c", 1)], {"a", ("b", 1)}, id="group-mapped-one"), + ], +) +def test_find_relevant_relatives(dag_maker, session, normal_tasks, mapped_tasks, expected): + # 1 -> 2[] -> 3[] -> 4 + # + # a -> " b --> c " -> d + # "== g[] ==" + with dag_maker(session=session) as dag: + t1 = EmptyOperator(task_id="1") + t2 = MockOperator.partial(task_id="2").expand(arg1=["x", "y"]) + t3 = MockOperator.partial(task_id="3").expand(arg1=["x", "y"]) + t4 = EmptyOperator(task_id="4") + t1 >> t2 >> t3 >> t4 + + ta = EmptyOperator(task_id="a") + + @task_group(prefix_group_id=False) + def g(v): + tb = MockOperator(task_id="b", arg1=v) + tc = MockOperator(task_id="c", arg1=v) + tb >> tc + + td = EmptyOperator(task_id="d") + ta >> g.expand(v=["x", "y", "z"]) >> td + + dr = dag_maker.create_dagrun(state="success") + + result = find_relevant_relatives( + normal_tasks=normal_tasks, + mapped_tasks=mapped_tasks, + direction="upstream", + dag=dag, + run_id=dr.run_id, + session=session, + ) + assert result == expected