Skip to content

Commit

Permalink
Fix scheduler crash when expanding with mapped task that returned none (
Browse files Browse the repository at this point in the history
#23486)

When task is expanded from a mapped task that returned no value, it
crashes the scheduler. This PR fixes it by first checking if there's
a return value from the mapped task, if no returned value, then error
in the task itself instead of crashing the scheduler

(cherry picked from commit 7813f99)
  • Loading branch information
ephraimbuddy committed May 17, 2022
1 parent b369085 commit cceccf2
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
4 changes: 3 additions & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2329,10 +2329,12 @@ def _record_task_map_for_downstreams(self, task: "Operator", value: Any, *, sess
# currently possible for a downstream to depend on one individual mapped
# task instance, only a task as a whole. This will change in AIP-42
# Phase 2, and we'll need to further analyze the mapped task case.
if task.is_mapped or next(task.iter_mapped_dependants(), None) is None:
if next(task.iter_mapped_dependants(), None) is None:
return
if value is None:
raise XComForMappingNotPushed()
if task.is_mapped:
return
if not isinstance(value, collections.abc.Collection) or isinstance(value, (bytes, str)):
raise UnmappableXComTypePushed(value)
task_map = TaskMap.from_task_instance_xcom(self, value)
Expand Down
18 changes: 18 additions & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2831,3 +2831,21 @@ def add_one(x):

query = XCom.get_many(run_id=dagrun.run_id, task_ids=["add_one__1"], session=session)
assert [x.value for x in query.order_by(None).order_by(XCom.map_index)] == [3, 4, 5]


def test_ti_mapped_depends_on_mapped_xcom_arg_XXX(dag_maker, session):
with dag_maker(session=session) as dag:

@dag.task
def add_one(x):
x + 1

two_three_four = add_one.expand(x=[1, 2, 3])
add_one.expand(x=two_three_four)

dagrun = dag_maker.create_dagrun()
for map_index in range(3):
ti = dagrun.get_task_instance("add_one", map_index=map_index)
ti.refresh_from_task(dag.get_task("add_one"))
with pytest.raises(XComForMappingNotPushed):
ti.run()

0 comments on commit cceccf2

Please sign in to comment.