Function to expand mapped tasks in to multiple "real" TIs #21019

Merged
merged 7 commits on Jan 26, 2022
Changes from 3 commits
68 changes: 68 additions & 0 deletions airflow/models/baseoperator.py
@@ -75,6 +75,7 @@
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.operator_resources import Resources
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule

@@ -1800,6 +1801,73 @@ def wait_for_downstream(self) -> bool:
def depends_on_past(self) -> bool:
return self.partial_kwargs.get("depends_on_past") or self.wait_for_downstream

def expand_mapped_task(self, upstream_ti: "TaskInstance", session: "Session" = NEW_SESSION) -> None:
"""Create the mapped TaskInstances for mapped task."""
# TODO: support having multiple mapped upstreams?
from airflow.models.taskmap import TaskMap
from airflow.settings import task_instance_mutation_hook

task_map_info_length: Optional[int] = (
session.query(TaskMap.length)
.filter_by(
dag_id=upstream_ti.dag_id,
task_id=upstream_ti.task_id,
run_id=upstream_ti.run_id,
map_index=upstream_ti.map_index,
)
.scalar()
)
if task_map_info_length is None:
# TODO: What would lead to this? How can this be better handled?
raise RuntimeError("mapped operator cannot be expanded; upstream not found")
# TODO: Add db constraint to ensure this is never negative.

unmapped_ti: Optional[TaskInstance] = (
session.query(TaskInstance)
.filter(
TaskInstance.dag_id == upstream_ti.dag_id,
TaskInstance.run_id == upstream_ti.run_id,
TaskInstance.task_id == self.task_id,
TaskInstance.map_index == -1,
TaskInstance.state.in_(State.unfinished),
)
.one_or_none()
)

if unmapped_ti:
# The unmapped task instance still exists and is unfinished, i.e. we
# haven't tried to run it before.
if task_map_info_length < 1:
# If the upstream maps this to a zero-length value, simply mark the
# unmapped task instance as SKIPPED (if needed).
unmapped_ti.state = TaskInstanceState.SKIPPED
session.merge(unmapped_ti)
return
# Otherwise convert this into the first mapped index, and create
# TaskInstance for other indexes.
unmapped_ti.map_index = 0
indexes_to_map = range(1, task_map_info_length)
else:
indexes_to_map = range(task_map_info_length)

for index in indexes_to_map:
# TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
# TODO: Change `TaskInstance` ctor to take Operator, not BaseOperator
ti = TaskInstance(self, run_id=upstream_ti.run_id, map_index=index) # type: ignore
task_instance_mutation_hook(ti)
session.merge(ti)

# Set to "REMOVED" any (old) TaskInstances with map indices greater
# than the current map value
session.query(TaskInstance).filter(
TaskInstance.dag_id == upstream_ti.dag_id,
TaskInstance.task_id == self.task_id,
TaskInstance.run_id == upstream_ti.run_id,
TaskInstance.map_index >= task_map_info_length,
).update({TaskInstance.state: TaskInstanceState.REMOVED})

session.flush()


# TODO: Deprecate for Airflow 3.0
Chainable = Union[DependencyMixin, Sequence[DependencyMixin]]
3 changes: 2 additions & 1 deletion airflow/models/taskinstance.py
@@ -486,7 +486,7 @@ def __init__(
self.test_mode = False

@staticmethod
def insert_mapping(run_id: str, task: "BaseOperator") -> dict:
def insert_mapping(run_id: str, task: "BaseOperator", map_index: int = -1) -> dict:
""":meta private:"""
return {
'dag_id': task.dag_id,
@@ -503,6 +503,7 @@ def insert_mapping(run_id: str, task: "BaseOperator") -> dict:
'max_tries': task.retries,
'executor_config': task.executor_config,
'operator': task.task_type,
'map_index': map_index,
}

@reconstructor
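The TODO in expand_mapped_task() above mentions switching the per-index merge loop to bulk inserts; the map_index parameter added to insert_mapping() here is what would make that possible. A rough, hypothetical sketch of that direction (not part of this change -- the helper name is invented, and a bulk insert would skip task_instance_mutation_hook(), so it is not a drop-in replacement):

from airflow.models.taskinstance import TaskInstance


def bulk_create_mapped_tis(session, task, run_id, indexes_to_map):
    # Build one plain dict per map index instead of constructing full TaskInstance
    # objects, then insert them with a single bulk call (hypothetical helper).
    mappings = [
        TaskInstance.insert_mapping(run_id, task, map_index=index)
        for index in indexes_to_map
    ]
    # NOTE: unlike the current loop, this bypasses task_instance_mutation_hook().
    session.bulk_insert_mappings(TaskInstance, mappings)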
66 changes: 64 additions & 2 deletions tests/models/test_baseoperator.py
@@ -36,8 +36,12 @@
chain,
cross_downstream,
)
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
from airflow.models.xcom_arg import XComArg
from airflow.utils.context import Context
from airflow.utils.edgemodifier import Label
from airflow.utils.state import TaskInstanceState
from airflow.utils.task_group import TaskGroup
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule
@@ -733,8 +737,6 @@ def test_map_unknown_arg_raises():

def test_map_xcom_arg():
"""Test that dependencies are correct when mapping with an XComArg"""
from airflow.models.xcom_arg import XComArg

with DAG("test-dag", start_date=DEFAULT_DATE):
task1 = BaseOperator(task_id="op1")
xcomarg = XComArg(task1, "test_key")
@@ -767,3 +769,63 @@ def test_partial_on_class_invalid_ctor_args() -> None:
"""
with pytest.raises(TypeError, match=r"arguments 'foo', 'bar'"):
MockOperator.partial(task_id='a', foo='bar', bar=2)


@pytest.mark.parametrize(
["num_existing_tis", "expected"],
(
pytest.param(0, [(0, None), (1, None), (2, None)], id='only-unmapped-ti-exists'),
pytest.param(3, [(0, None), (1, None), (2, None)], id='all-tis-exist'),
pytest.param(
5,
[(0, None), (1, None), (2, None), (3, TaskInstanceState.REMOVED), (4, TaskInstanceState.REMOVED)],
id="tis-to-be-remove",
),
),
)
def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expected):
literal = [1, 2, {'a': 'b'}]
with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
xcomarg = XComArg(task1, "test_key")
mapped = MockOperator(task_id='task_2').map(arg2=xcomarg)

dr = dag_maker.create_dagrun()

session.add(
TaskMap(
dag_id=dr.dag_id,
task_id=task1.task_id,
run_id=dr.run_id,
map_index=-1,
length=len(literal),
keys=None,
)
)

if num_existing_tis:
# Remove the map_index=-1 TI when we're creating other TIs
session.query(TaskInstance).filter(
TaskInstance.dag_id == mapped.dag_id,
TaskInstance.task_id == mapped.task_id,
TaskInstance.run_id == dr.run_id,
).delete()

for index in range(num_existing_tis):
ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index)
session.add(ti)
session.flush()

mapped.expand_mapped_task(
upstream_ti=dr.get_task_instance(task1.task_id),
session=session,
)

indices = (
session.query(TaskInstance.map_index, TaskInstance.state)
.filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
.order_by(TaskInstance.map_index)
.all()
)

assert indices == expected
18 changes: 18 additions & 0 deletions tests/models/test_dagrun.py
@@ -39,6 +39,7 @@
from airflow.utils.types import DagRunType
from tests.models import DEFAULT_DATE
from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs
from tests.test_utils.mock_operators import MockOperator


class TestDagRun(unittest.TestCase):
@@ -874,3 +875,20 @@ def test_verify_integrity_task_start_date(Stats_incr, session, run_type, expecte
assert len(tis) == expected_tis

Stats_incr.assert_called_with('task_instance_created-DummyOperator', expected_tis)


@pytest.mark.xfail(reason="TODO: Expand mapped literals at verify_integrity time!")
Member Author:

I'm not sure this is actually a good idea -- although we could put it here, that puts more work in the core scheduler loop, so I think we could reasonably delay this to the mini scheduler in the upstream task.

Thoughts?

Member:

I'm not sure we can have it only in the mini scheduler, as that can be turned off. Maybe toggle where it is run based on the mini scheduler being on or off?

Member Author:

It won't only be in the mini scheduler run; there will still be an "expansion of last resort" in the scheduler. I guess the difference is whether we want to do the expansion eagerly at DagRun creation time, when it could possibly be done in another process (the LocalTaskJob).

It's probably going to be quite rare in practice that maps will be literals, so I think it's not even worth the cost of checking this here, given that it's so unlikely it will do anything useful.
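A minimal sketch of the "toggle based on whether the mini scheduler is on" idea above, purely for illustration -- the helper and its call site are assumptions, not part of this PR; only the [scheduler] schedule_after_task_execution option and expand_mapped_task() exist:

from airflow.configuration import conf


def maybe_expand_eagerly(mapped_task, upstream_ti, session):
    # schedule_after_task_execution controls the mini scheduler run in LocalTaskJob.
    if not conf.getboolean("scheduler", "schedule_after_task_execution"):
        # No mini scheduler will pick this up after the upstream task finishes,
        # so pay the expansion cost eagerly in the core scheduler loop instead.
        mapped_task.expand_mapped_task(upstream_ti=upstream_ti, session=session)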

def test_expand_mapped_task_instance(dag_maker, session):
literal = [1, 2, {'a': 'b'}]
with dag_maker(session=session):
mapped = MockOperator(task_id='task_2').map(arg2=literal)

dr = dag_maker.create_dagrun()
indices = (
session.query(TI.map_index)
.filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
.order_by(TI.map_index)
.all()
)

assert indices == [(0,), (1,), (2,)]