Function to expand mapped tasks into multiple "real" TIs (#21019)
Mark the unmapped TI as SKIPPED if the upstream map is empty

Co-authored-by: Tzu-ping Chung <tp@astronomer.io>
ashb and uranusjr authored Jan 26, 2022
1 parent 80f30ee commit e45f0e9
Showing 5 changed files with 175 additions and 4 deletions.
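
The core of the change is the new expand_mapped_task method added to airflow/models/baseoperator.py below. Given the length of the upstream map recorded in TaskMap, it marks the still-unmapped task instance SKIPPED when the map is empty, otherwise reuses that instance as map index 0 and creates TaskInstances for the remaining indexes, and finally marks any indexes left over from a previously larger map as REMOVED. The following is a rough standalone model of that decision logic only (plain Python; FakeTI and plan_expansion are illustrative stand-ins invented for this sketch, while the real method works through a SQLAlchemy session and session.merge):

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeTI:
    # Minimal stand-in for a TaskInstance row: only the fields the
    # expansion decision looks at.
    map_index: int
    state: Optional[str] = None


def plan_expansion(map_length: int, existing: List[FakeTI]) -> List[FakeTI]:
    """Mimic the skip / reuse / create / remove decisions of expand_mapped_task."""
    unmapped = next((ti for ti in existing if ti.map_index == -1), None)
    if unmapped is not None:
        if map_length < 1:
            # Empty upstream map: the unmapped TI is simply marked SKIPPED.
            unmapped.state = "skipped"
            return existing
        # The unmapped TI becomes map index 0; only the rest need creating.
        unmapped.map_index = 0
    # The real code relies on session.merge() for de-duplication; a set of
    # already-present indexes plays that role here.
    present = {ti.map_index for ti in existing}
    existing.extend(FakeTI(map_index=i) for i in range(map_length) if i not in present)
    for ti in existing:
        if ti.map_index >= map_length:
            # Indexes left over from a previously larger map are REMOVED.
            ti.state = "removed"
    return sorted(existing, key=lambda ti: ti.map_index)


# Mirrors the cases in tests/models/test_baseoperator.py below:
print(plan_expansion(3, [FakeTI(-1)]))                   # indexes 0-2, no state
print(plan_expansion(0, [FakeTI(-1)]))                   # the single TI is skipped
print(plan_expansion(3, [FakeTI(i) for i in range(5)]))  # indexes 3 and 4 removed
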
69 changes: 69 additions & 0 deletions airflow/models/baseoperator.py
@@ -50,6 +50,7 @@
import jinja2
import pendulum
from dateutil.relativedelta import relativedelta
from sqlalchemy import or_
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import NoResultFound

@@ -75,6 +76,7 @@
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.operator_resources import Resources
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule

@@ -1800,6 +1802,73 @@ def wait_for_downstream(self) -> bool:
def depends_on_past(self) -> bool:
return self.partial_kwargs.get("depends_on_past") or self.wait_for_downstream

def expand_mapped_task(self, upstream_ti: "TaskInstance", session: "Session" = NEW_SESSION) -> None:
"""Create the mapped TaskInstances for mapped task."""
# TODO: support having multiple mapped upstreams?
from airflow.models.taskmap import TaskMap
from airflow.settings import task_instance_mutation_hook

task_map_info_length: Optional[int] = (
session.query(TaskMap.length)
.filter_by(
dag_id=upstream_ti.dag_id,
task_id=upstream_ti.task_id,
run_id=upstream_ti.run_id,
map_index=upstream_ti.map_index,
)
.scalar()
)
if task_map_info_length is None:
# TODO: What would lead to this? How can this be better handled?
raise RuntimeError("mapped operator cannot be expanded; upstream not found")

unmapped_ti: Optional[TaskInstance] = (
session.query(TaskInstance)
.filter(
TaskInstance.dag_id == upstream_ti.dag_id,
TaskInstance.run_id == upstream_ti.run_id,
TaskInstance.task_id == self.task_id,
TaskInstance.map_index == -1,
or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
)
.one_or_none()
)

if unmapped_ti:
# The unmapped task instance still exists and is unfinished, i.e. we
# haven't tried to run it before.
if task_map_info_length < 1:
# If the upstream maps this to a zero-length value, simply mark the
# unmapped task instance as SKIPPED (if needed).
self.log.info("Marking %s as SKIPPED since the map has 0 values to expand", unmapped_ti)
unmapped_ti.state = TaskInstanceState.SKIPPED
session.flush()
return
# Otherwise convert this into the first mapped index, and create
# TaskInstance for other indexes.
unmapped_ti.map_index = 0
indexes_to_map = range(1, task_map_info_length)
else:
indexes_to_map = range(task_map_info_length)

for index in indexes_to_map:
# TODO: Make more efficient with bulk_insert_mappings/bulk_save_objects.
# TODO: Change `TaskInstance` ctor to take Operator, not BaseOperator
ti = TaskInstance(self, run_id=upstream_ti.run_id, map_index=index) # type: ignore
task_instance_mutation_hook(ti)
session.merge(ti)

# Set to "REMOVED" any (old) TaskInstances with map indices greater
# than the current map value
session.query(TaskInstance).filter(
TaskInstance.dag_id == upstream_ti.dag_id,
TaskInstance.task_id == self.task_id,
TaskInstance.run_id == upstream_ti.run_id,
TaskInstance.map_index >= task_map_info_length,
).update({TaskInstance.state: TaskInstanceState.REMOVED})

session.flush()


# TODO: Deprecate for Airflow 3.0
Chainable = Union[DependencyMixin, Sequence[DependencyMixin]]
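
The final step of expand_mapped_task above is a bulk UPDATE: any TaskInstance rows whose map index falls at or beyond the new map length are set to REMOVED in a single query. Below is a self-contained sketch of that query-then-update pattern, using a hypothetical ToyTI model on in-memory SQLite rather than Airflow's task_instance table (assumes SQLAlchemy 1.4+):

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class ToyTI(Base):
    # Hypothetical stand-in for the task_instance table.
    __tablename__ = "toy_ti"
    task_id = Column(String, primary_key=True)
    map_index = Column(Integer, primary_key=True)
    state = Column(String, nullable=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # Five mapped instances remain from an earlier, larger expansion.
    session.add_all([ToyTI(task_id="t", map_index=i) for i in range(5)])
    session.flush()

    new_length = 3  # the upstream map now yields only three values
    session.query(ToyTI).filter(
        ToyTI.task_id == "t",
        ToyTI.map_index >= new_length,
    ).update({ToyTI.state: "removed"})
    session.flush()

    rows = session.query(ToyTI.map_index, ToyTI.state).order_by(ToyTI.map_index).all()
    print(rows)  # [(0, None), (1, None), (2, None), (3, 'removed'), (4, 'removed')]
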
2 changes: 1 addition & 1 deletion airflow/models/dagrun.py
@@ -813,7 +813,7 @@ def task_filter(task: "BaseOperator"):

def create_ti_mapping(task: "BaseOperator"):
created_counts[task.task_type] += 1
return TI.insert_mapping(self.run_id, task)
return TI.insert_mapping(self.run_id, task, map_index=-1)

else:

3 changes: 2 additions & 1 deletion airflow/models/taskinstance.py
@@ -487,7 +487,7 @@ def __init__(
self.test_mode = False

@staticmethod
def insert_mapping(run_id: str, task: "BaseOperator") -> dict:
def insert_mapping(run_id: str, task: "BaseOperator", map_index: int) -> dict:
""":meta private:"""
return {
'dag_id': task.dag_id,
@@ -504,6 +504,7 @@ def insert_mapping(run_id: str, task: "BaseOperator") -> dict:
'max_tries': task.retries,
'executor_config': task.executor_config,
'operator': task.task_type,
'map_index': map_index,
}

@reconstructor
87 changes: 85 additions & 2 deletions tests/models/test_baseoperator.py
@@ -36,8 +36,12 @@
chain,
cross_downstream,
)
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
from airflow.models.xcom_arg import XComArg
from airflow.utils.context import Context
from airflow.utils.edgemodifier import Label
from airflow.utils.state import TaskInstanceState
from airflow.utils.task_group import TaskGroup
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule
@@ -733,8 +737,6 @@ def test_map_unknown_arg_raises():

def test_map_xcom_arg():
"""Test that dependencies are correct when mapping with an XComArg"""
from airflow.models.xcom_arg import XComArg

with DAG("test-dag", start_date=DEFAULT_DATE):
task1 = BaseOperator(task_id="op1")
xcomarg = XComArg(task1, "test_key")
@@ -767,3 +769,84 @@ def test_partial_on_class_invalid_ctor_args() -> None:
"""
with pytest.raises(TypeError, match=r"arguments 'foo', 'bar'"):
MockOperator.partial(task_id='a', foo='bar', bar=2)


@pytest.mark.parametrize(
["num_existing_tis", "expected"],
(
pytest.param(0, [(0, None), (1, None), (2, None)], id='only-unmapped-ti-exists'),
pytest.param(3, [(0, None), (1, None), (2, None)], id='all-tis-exist'),
pytest.param(
5,
[(0, None), (1, None), (2, None), (3, TaskInstanceState.REMOVED), (4, TaskInstanceState.REMOVED)],
id="tis-to-be-remove",
),
),
)
def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expected):
literal = [1, 2, {'a': 'b'}]
with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
xcomarg = XComArg(task1, "test_key")
mapped = MockOperator(task_id='task_2').map(arg2=xcomarg)

dr = dag_maker.create_dagrun()

session.add(
TaskMap(
dag_id=dr.dag_id,
task_id=task1.task_id,
run_id=dr.run_id,
map_index=-1,
length=len(literal),
keys=None,
)
)

if num_existing_tis:
# Remove the map_index=-1 TI when we're creating other TIs
session.query(TaskInstance).filter(
TaskInstance.dag_id == mapped.dag_id,
TaskInstance.task_id == mapped.task_id,
TaskInstance.run_id == dr.run_id,
).delete()

for index in range(num_existing_tis):
ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index) # type: ignore
session.add(ti)
session.flush()

mapped.expand_mapped_task(upstream_ti=dr.get_task_instance(task1.task_id), session=session)

indices = (
session.query(TaskInstance.map_index, TaskInstance.state)
.filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
.order_by(TaskInstance.map_index)
.all()
)

assert indices == expected


def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session):
with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
xcomarg = XComArg(task1, "test_key")
mapped = MockOperator(task_id='task_2').map(arg2=xcomarg)

dr = dag_maker.create_dagrun()

session.add(
TaskMap(dag_id=dr.dag_id, task_id=task1.task_id, run_id=dr.run_id, map_index=-1, length=0, keys=None)
)

mapped.expand_mapped_task(upstream_ti=dr.get_task_instance(task1.task_id), session=session)

indices = (
session.query(TaskInstance.map_index, TaskInstance.state)
.filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
.order_by(TaskInstance.map_index)
.all()
)

assert indices == [(-1, TaskInstanceState.SKIPPED)]
18 changes: 18 additions & 0 deletions tests/models/test_dagrun.py
@@ -39,6 +39,7 @@
from airflow.utils.types import DagRunType
from tests.models import DEFAULT_DATE
from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs
from tests.test_utils.mock_operators import MockOperator


class TestDagRun(unittest.TestCase):
@@ -874,3 +875,20 @@ def test_verify_integrity_task_start_date(Stats_incr, session, run_type, expecte
assert len(tis) == expected_tis

Stats_incr.assert_called_with('task_instance_created-DummyOperator', expected_tis)


@pytest.mark.xfail(reason="TODO: Expand mapped literals at verify_integrity time!")
def test_expand_mapped_task_instance(dag_maker, session):
literal = [1, 2, {'a': 'b'}]
with dag_maker(session=session):
mapped = MockOperator(task_id='task_2').map(arg2=literal)

dr = dag_maker.create_dagrun()
indices = (
session.query(TI.map_index)
.filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
.order_by(TI.map_index)
.all()
)

assert indices == [(0,), (1,), (2,)]
