Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.trigger import Trigger
from airflow.models.xcom import XComModel
from airflow.sdk.definitions._internal.expandinput import NotFullyPopulated
from airflow.sdk.definitions.taskgroup import MappedTaskGroup
from airflow.utils import timezone
from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
Expand Down Expand Up @@ -244,7 +245,9 @@ def ti_run(
)

if dag := dag_bag.get_dag(ti.dag_id):
upstream_map_indexes = dict(_get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index))
upstream_map_indexes = dict(
_get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index, ti.run_id, session)
)
else:
upstream_map_indexes = None

Expand Down Expand Up @@ -274,7 +277,7 @@ def ti_run(


def _get_upstream_map_indexes(
task: Operator, ti_map_index: int
task: Operator, ti_map_index: int, run_id: str, session: SessionDep
) -> Iterator[tuple[str, int | list[int] | None]]:
for upstream_task in task.upstream_list:
map_indexes: int | list[int] | None
Expand All @@ -287,8 +290,17 @@ def _get_upstream_map_indexes(
map_indexes = ti_map_index
else:
# tasks not in the same mapped task group
# the upstream mapped task group should combine the xcom as a list and return it
mapped_ti_count: int = upstream_task.task_group.get_parse_time_mapped_ti_count()
# the upstream mapped task group should combine the return xcom as a list and return it
mapped_ti_count: int
upstream_mapped_group = upstream_task.task_group
try:
# for cases that does not need to resolve xcom
mapped_ti_count = upstream_mapped_group.get_parse_time_mapped_ti_count()
except NotFullyPopulated:
# for cases that needs to resolve xcom to get the correct count
mapped_ti_count = upstream_mapped_group._expand_input.get_total_map_length(
run_id, session=session
)
map_indexes = list(range(mapped_ti_count)) if mapped_ti_count is not None else None

yield upstream_task.task_id, map_indexes
Expand Down
13 changes: 9 additions & 4 deletions airflow-core/src/airflow/models/expandinput.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import functools
import operator
from collections.abc import Iterable, Sized
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, ClassVar, Union

import attrs

Expand All @@ -32,7 +32,6 @@

from airflow.sdk.definitions._internal.expandinput import (
DictOfListsExpandInput,
ExpandInput,
ListOfDictsExpandInput,
MappedArgument,
NotFullyPopulated,
Expand Down Expand Up @@ -62,6 +61,8 @@ def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArg
class SchedulerDictOfListsExpandInput:
value: dict

EXPAND_INPUT_TYPE: ClassVar[str] = "dict-of-lists"

def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
"""Generate kwargs with values available on parse-time."""
return ((k, v) for k, v in self.value.items() if not _needs_run_time_resolution(v))
Expand Down Expand Up @@ -114,6 +115,8 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
class SchedulerListOfDictsExpandInput:
value: list

EXPAND_INPUT_TYPE: ClassVar[str] = "list-of-dicts"

def get_parse_time_mapped_ti_count(self) -> int:
if isinstance(self.value, Sized):
return len(self.value)
Expand All @@ -130,11 +133,13 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
return length


_EXPAND_INPUT_TYPES = {
_EXPAND_INPUT_TYPES: dict[str, type[SchedulerExpandInput]] = {
"dict-of-lists": SchedulerDictOfListsExpandInput,
"list-of-dicts": SchedulerListOfDictsExpandInput,
}

SchedulerExpandInput = Union[SchedulerDictOfListsExpandInput, SchedulerListOfDictsExpandInput]


def create_expand_input(kind: str, value: Any) -> ExpandInput:
def create_expand_input(kind: str, value: Any) -> SchedulerExpandInput:
return _EXPAND_INPUT_TYPES[kind](value)
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
from inspect import Parameter

from airflow.models import DagRun
from airflow.models.expandinput import ExpandInput
from airflow.models.expandinput import SchedulerExpandInput
from airflow.sdk import BaseOperatorLink
from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.types import Operator
Expand Down Expand Up @@ -557,7 +557,7 @@ def validate_expand_input_value(cls, value: _ExpandInputOriginalValue) -> None:
possible ExpandInput cases.
"""

def deref(self, dag: DAG) -> ExpandInput:
def deref(self, dag: DAG) -> SchedulerExpandInput:
"""
De-reference into a concrete ExpandInput object.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk import TaskGroup
from airflow.sdk import TaskGroup, task, task_group
from airflow.utils import timezone
from airflow.utils.state import State, TaskInstanceState, TerminalTIState

Expand Down Expand Up @@ -237,6 +237,128 @@ def test_ti_run_state_to_running(
)
assert response.status_code == 409

def test_dynamic_task_mapping_with_parse_time_value(self, client, dag_maker):
"""
Test that the Task Instance upstream_map_indexes is correctly fetched when to running the Task Instances
"""

with dag_maker("test_dynamic_task_mapping_with_parse_time_value", serialized=True):

@task_group
def task_group_1(arg1):
@task
def group1_task_1(arg1):
return {"a": arg1}

@task
def group1_task_2(arg2):
return arg2

group1_task_2(group1_task_1(arg1))

@task
def task2():
return None

task_group_1.expand(arg1=[0, 1]) >> task2()

dr = dag_maker.create_dagrun()
for ti in dr.get_task_instances():
ti.set_state(State.QUEUED)
dag_maker.session.flush()

# key: (task_id, map_index)
# value: result upstream_map_indexes ({task_id: map_indexes})
expected_upstream_map_indexes = {
# no upstream task for task_group_1.group_task_1
("task_group_1.group1_task_1", 0): {},
("task_group_1.group1_task_1", 1): {},
# the upstream task for task_group_1.group_task_2 is task_group_1.group_task_2
# since they are in the same task group, the upstream map index should be the same as the task
("task_group_1.group1_task_2", 0): {"task_group_1.group1_task_1": 0},
("task_group_1.group1_task_2", 1): {"task_group_1.group1_task_1": 1},
# the upstream task for task2 is the last tasks of task_group_1, which is
# task_group_1.group_task_2
# since they are not in the same task group, the upstream map index should include all the
# expanded tasks
("task2", -1): {"task_group_1.group1_task_2": [0, 1]},
}

for ti in dr.get_task_instances():
response = client.patch(
f"/execution/task-instances/{ti.id}/run",
json={
"state": "running",
"hostname": "random-hostname",
"unixname": "random-unixname",
"pid": 100,
"start_date": "2024-09-30T12:00:00Z",
},
)

assert response.status_code == 200
upstream_map_indexes = response.json()["upstream_map_indexes"]
assert upstream_map_indexes == expected_upstream_map_indexes[(ti.task_id, ti.map_index)]

def test_dynamic_task_mapping_with_xcom(self, client, dag_maker, create_task_instance, session, run_task):
"""
Test that the Task Instance upstream_map_indexes is correctly fetched when to running the Task Instances with xcom
"""
from airflow.models.taskmap import TaskMap

with dag_maker(session=session):

@task
def task_1():
return [0, 1]

@task_group
def tg(x, y):
@task
def task_2():
pass

task_2()

@task
def task_3():
pass

tg.expand(x=task_1(), y=[1, 2, 3]) >> task_3()

dr = dag_maker.create_dagrun()

decision = dr.task_instance_scheduling_decisions(session=session)

# Simulate task_1 execution to produce TaskMap.
(ti_1,) = decision.schedulable_tis
# ti_1 = dr.get_task_instance(task_id="task_1")
ti_1.state = TaskInstanceState.SUCCESS
session.add(TaskMap.from_task_instance_xcom(ti_1, [0, 1]))
session.flush()

# Now task_2 in mapped tagk group is expanded.
decision = dr.task_instance_scheduling_decisions(session=session)
for ti in decision.schedulable_tis:
ti.state = TaskInstanceState.SUCCESS
session.flush()

decision = dr.task_instance_scheduling_decisions(session=session)
(task_3_ti,) = decision.schedulable_tis
task_3_ti.set_state(State.QUEUED)

response = client.patch(
f"/execution/task-instances/{task_3_ti.id}/run",
json={
"state": "running",
"hostname": "random-hostname",
"unixname": "random-unixname",
"pid": 100,
"start_date": "2024-09-30T12:00:00Z",
},
)
assert response.json()["upstream_map_indexes"] == {"tg.task_2": [0, 1, 2, 3, 4, 5]}

def test_next_kwargs_still_encoded(self, client, session, create_task_instance, time_machine):
instant_str = "2024-09-30T12:00:00Z"
instant = timezone.parse(instant_str)
Expand Down
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/definitions/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@
TaskStateChangeCallback,
)
from airflow.models.expandinput import (
ExpandInput,
OperatorExpandArgument,
OperatorExpandKwargsArgument,
)
from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.bases.operatorlink import BaseOperatorLink
from airflow.sdk.definitions._internal.expandinput import ExpandInput
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.param import ParamsDict
from airflow.sdk.definitions.xcom_arg import XComArg
Expand Down
4 changes: 2 additions & 2 deletions task-sdk/src/airflow/sdk/definitions/taskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from airflow.utils.trigger_rule import TriggerRule

if TYPE_CHECKING:
from airflow.models.expandinput import ExpandInput
from airflow.models.expandinput import SchedulerExpandInput
from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
from airflow.sdk.definitions._internal.mixins import DependencyMixin
Expand Down Expand Up @@ -613,7 +613,7 @@ class MappedTaskGroup(TaskGroup):
a ``@task_group`` function instead.
"""

def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None:
def __init__(self, *, expand_input: SchedulerExpandInput, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._expand_input = expand_input

Expand Down