Skip to content

Commit

Permalink
Fix short circuit in mapped tasks (#44912)
Browse files Browse the repository at this point in the history
  • Loading branch information
shahar1 authored Dec 18, 2024
1 parent 878aab3 commit ec4db3e
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 45 deletions.
2 changes: 2 additions & 0 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,8 @@ def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) ->
op.is_setup = is_setup
op.is_teardown = is_teardown
op.on_failure_fail_dagrun = on_failure_fail_dagrun
op.downstream_task_ids = self.downstream_task_ids
op.upstream_task_ids = self.upstream_task_ids
return op

# After a mapped operator is serialized, there's no real way to actually
Expand Down
9 changes: 6 additions & 3 deletions airflow/models/skipmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,11 @@ def _skip(
raise ValueError("dag_run is required")

task_ids_list = [d.task_id for d in task_list]
SkipMixin._set_state_to_skipped(dag_run, task_ids_list, session)
session.commit()

# The following can be applied only to non-mapped tasks
if map_index == -1:
SkipMixin._set_state_to_skipped(dag_run, task_ids_list, session)
session.commit()

if task_id is not None:
from airflow.models.xcom import XCom
Expand All @@ -177,8 +180,8 @@ def _skip(
session=session,
)

@staticmethod
def skip_all_except(
self,
ti: TaskInstance | TaskInstancePydantic,
branch_task_ids: None | str | Iterable[str],
):
Expand Down
79 changes: 40 additions & 39 deletions airflow/ti_deps/deps/not_previously_skipped_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from airflow.models.taskinstance import PAST_DEPENDS_MET
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.db import LazySelectSequence


class NotPreviouslySkippedDep(BaseTIDep):
Expand All @@ -38,7 +39,6 @@ def _get_dep_statuses(self, ti, session, dep_context):
XCOM_SKIPMIXIN_FOLLOWED,
XCOM_SKIPMIXIN_KEY,
XCOM_SKIPMIXIN_SKIPPED,
SkipMixin,
)
from airflow.utils.state import TaskInstanceState

Expand All @@ -49,46 +49,47 @@ def _get_dep_statuses(self, ti, session, dep_context):
finished_task_ids = {t.task_id for t in finished_tis}

for parent in upstream:
if isinstance(parent, SkipMixin):
if parent.task_id not in finished_task_ids:
# This can happen if the parent task has not yet run.
continue
if parent.task_id not in finished_task_ids:
# This can happen if the parent task has not yet run.
continue

prev_result = ti.xcom_pull(task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY, session=session)
prev_result = ti.xcom_pull(
task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY, session=session, map_indexes=ti.map_index
)

if prev_result is None:
# This can happen if the parent task has not yet run.
continue
if isinstance(prev_result, LazySelectSequence):
prev_result = next(iter(prev_result))

should_skip = False
if (
XCOM_SKIPMIXIN_FOLLOWED in prev_result
and ti.task_id not in prev_result[XCOM_SKIPMIXIN_FOLLOWED]
):
# Skip any tasks that are not in "followed"
should_skip = True
elif (
XCOM_SKIPMIXIN_SKIPPED in prev_result
and ti.task_id in prev_result[XCOM_SKIPMIXIN_SKIPPED]
):
# Skip any tasks that are in "skipped"
should_skip = True
if prev_result is None:
# This can happen if the parent task has not yet run.
continue

if should_skip:
# If the parent SkipMixin has run, and the XCom result stored indicates this
# ti should be skipped, set ti.state to SKIPPED and fail the rule so that the
# ti does not execute.
if dep_context.wait_for_past_depends_before_skipping:
past_depends_met = ti.xcom_pull(
task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
)
if not past_depends_met:
yield self._failing_status(
reason=("Task should be skipped but the past depends are not met")
)
return
ti.set_state(TaskInstanceState.SKIPPED, session)
yield self._failing_status(
reason=f"Skipping because of previous XCom result from parent task {parent.task_id}"
should_skip = False
if (
XCOM_SKIPMIXIN_FOLLOWED in prev_result
and ti.task_id not in prev_result[XCOM_SKIPMIXIN_FOLLOWED]
):
# Skip any tasks that are not in "followed"
should_skip = True
elif XCOM_SKIPMIXIN_SKIPPED in prev_result and ti.task_id in prev_result[XCOM_SKIPMIXIN_SKIPPED]:
# Skip any tasks that are in "skipped"
should_skip = True

if should_skip:
# If the parent SkipMixin has run, and the XCom result stored indicates this
# ti should be skipped, set ti.state to SKIPPED and fail the rule so that the
# ti does not execute.
if dep_context.wait_for_past_depends_before_skipping:
past_depends_met = ti.xcom_pull(
task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
)
return
if not past_depends_met:
yield self._failing_status(
reason="Task should be skipped but the past depends are not met"
)
return
ti.set_state(TaskInstanceState.SKIPPED, session)
yield self._failing_status(
reason=f"Skipping because of previous XCom result from parent task {parent.task_id}"
)
return
1 change: 1 addition & 0 deletions newsfragments/44912.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix short circuit operator in mapped tasks. The operator did not work until now due to a bug in ``NotPreviouslySkippedDep``. Please note that at the time of merging, this fix has been applied only to Airflow versions > 2.10.4 and < 3, and should be ported to v3 after PR #44925 is merged.
57 changes: 56 additions & 1 deletion tests/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from datetime import timedelta
from typing import TYPE_CHECKING
from unittest import mock
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pendulum
import pytest
Expand Down Expand Up @@ -1868,3 +1868,58 @@ def _one_scheduling_decision_iteration() -> dict[tuple[str, int], TaskInstance]:
assert tis["group.last", 0].state == State.SUCCESS
assert dr.get_task_instance("group.last", map_index=1, session=session).state == State.SKIPPED
assert tis["group.last", 2].state == State.SUCCESS


class TestMappedOperator:
    """Tests for ``MappedOperator.unmap`` using a mocked operator class."""

    @pytest.fixture
    def mock_operator_class(self):
        """Provide the mock used as the ``operator_class`` of the mapped operator."""
        operator_class = MagicMock(spec=type(BaseOperator))
        return operator_class

    @pytest.fixture
    @patch("airflow.serialization.serialized_objects.SerializedBaseOperator")
    def mapped_operator(self, _, mock_operator_class):
        """
        Build a ``MappedOperator`` wrapping the mocked operator class.

        The ``_`` argument receives the mock injected by the ``patch`` decorator
        and is intentionally unused.
        """
        init_kwargs = {
            "operator_class": mock_operator_class,
            "expand_input": MagicMock(),
            "partial_kwargs": {"task_id": "test_task"},
            "task_id": "test_task",
            "params": {},
            "deps": frozenset(),
            "operator_extra_links": [],
            "template_ext": [],
            "template_fields": [],
            "template_fields_renderers": {},
            "ui_color": "",
            "ui_fgcolor": "",
            "start_trigger_args": None,
            "start_from_trigger": False,
            "dag": None,
            "task_group": None,
            "start_date": None,
            "end_date": None,
            "is_empty": False,
            "task_module": MagicMock(),
            "task_type": "taske_type",
            "operator_name": "operator_name",
            "disallow_kwargs_override": False,
            "expand_input_attr": "expand_input",
        }
        return MappedOperator(**init_kwargs)

    def test_unmap_with_resolved_kwargs(self, mapped_operator, mock_operator_class):
        """Unmapping with resolved kwargs yields the operator instance carrying the task relations."""
        mapped_operator.upstream_task_ids = ["a"]
        mapped_operator.downstream_task_ids = ["b"]

        unmapped = mapped_operator.unmap({"param1": "value1"})

        # The unmapped operator is whatever the (mocked) operator class produced.
        assert unmapped == mock_operator_class.return_value
        assert unmapped.task_id == "test_task"
        assert unmapped.is_setup is False
        assert unmapped.is_teardown is False
        assert unmapped.on_failure_fail_dagrun is False
        # Upstream/downstream ids set on the mapped operator must be visible on the result.
        assert unmapped.upstream_task_ids == ["a"]
        assert unmapped.downstream_task_ids == ["b"]

    def test_unmap_runtime_error(self, mapped_operator):
        """Unmapping with ``None`` as the resolve argument is expected to raise ``RuntimeError``."""
        mapped_operator.upstream_task_ids = ["a"]
        mapped_operator.downstream_task_ids = ["b"]

        with pytest.raises(RuntimeError):
            mapped_operator.unmap(None)
39 changes: 37 additions & 2 deletions tests/models/test_skipmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
from __future__ import annotations

import datetime
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch

import pendulum
import pytest

from airflow import settings
from airflow.decorators import task, task_group
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.models import DagRun, MappedOperator
from airflow.models.skipmixin import SkipMixin
from airflow.models.taskinstance import TaskInstance as TI
from airflow.operators.empty import EmptyOperator
Expand Down Expand Up @@ -53,6 +54,10 @@ def setup_method(self):
def teardown_method(self):
    """Run ``clean_db`` as the per-test teardown step."""
    self.clean_db()

@pytest.fixture
def mock_session(self):
    """Provide a ``Mock`` restricted to the ``settings.Session`` spec."""
    session = Mock(spec=settings.Session)
    return session

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
@patch("airflow.utils.timezone.utcnow")
def test_skip(self, mock_now, dag_maker):
Expand Down Expand Up @@ -104,10 +109,40 @@ def test_skip_none_dagrun(self, mock_now, dag_maker):

def test_skip_none_tasks(self):
session = Mock()
SkipMixin().skip(dag_run=None, execution_date=None, tasks=[])
assert (
SkipMixin()._skip(dag_run=None, task_id=None, execution_date=None, tasks=[], session=session)
is None
)
assert not session.query.called
assert not session.commit.called

def test_skip_mapped_task(self, mock_session):
    """With ``map_index=2``, ``_skip`` must neither execute a statement nor commit."""
    mixin = SkipMixin()
    dag_run = MagicMock(spec=DagRun)
    downstream_tasks = [MagicMock(spec=MappedOperator)]

    mixin._skip(
        dag_run=dag_run,
        task_id=None,
        execution_date=None,
        tasks=downstream_tasks,
        session=mock_session,
        map_index=2,
    )

    mock_session.execute.assert_not_called()
    mock_session.commit.assert_not_called()

@patch("airflow.models.skipmixin.update")
def test_skip_none_mapped_task(self, mock_update, mock_session):
    """With ``map_index=-1``, ``_skip`` executes the update statement once and commits once."""
    mixin = SkipMixin()
    dag_run = MagicMock(spec=DagRun)
    downstream_tasks = [MagicMock(spec=MappedOperator)]

    mixin._skip(
        dag_run=dag_run,
        task_id=None,
        execution_date=None,
        tasks=downstream_tasks,
        session=mock_session,
        map_index=-1,
    )

    # The statement is built as update(...).where(...).values(...).execution_options(...);
    # follow the same chain on the patched ``update`` to get the object passed to execute().
    where_clause = mock_update.return_value.where.return_value
    expected_statement = where_clause.values.return_value.execution_options.return_value

    mock_session.execute.assert_called_once_with(expected_statement)
    mock_session.commit.assert_called_once()

@pytest.mark.parametrize(
"branch_task_ids, expected_states",
[
Expand Down
45 changes: 45 additions & 0 deletions tests/ti_deps/deps/test_not_previously_skipped_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pendulum
import pytest

from airflow.decorators import task
from airflow.models import DagRun, TaskInstance
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import BranchPythonOperator
Expand Down Expand Up @@ -84,6 +85,50 @@ def test_no_skipmixin_parent(session, dag_maker):
assert ti2.state != State.SKIPPED


@pytest.mark.parametrize("condition, final_state", [(True, State.SUCCESS), (False, State.SKIPPED)])
def test_parent_is_mapped_short_circuit(session, dag_maker, condition, final_state):
    """
    A mapped short-circuit task (op2) feeding a mapped downstream task (op3).

    When op2 returns True, op3 is schedulable and runs to SUCCESS. When op2 returns
    False, op3 is not run by the test and its task instance is expected to be SKIPPED.
    """
    with dag_maker(session=session):

        @task
        def op1():
            return [1]

        @task.short_circuit
        def op2(i: int):
            return condition

        @task
        def op3(res: bool):
            pass

        short_circuit_results = op2.expand(i=op1())
        op3.expand(res=short_circuit_results)

    dag_run = dag_maker.create_dagrun()

    def _schedulable_by_key() -> dict[tuple[str, int], TaskInstance]:
        # Index the currently schedulable task instances by (task_id, map_index).
        schedulable = dag_run.task_instance_scheduling_decisions(session=session).schedulable_tis
        return {(ti.task_id, ti.map_index): ti for ti in schedulable}

    op1_ti = _schedulable_by_key()["op1", -1]
    op1_ti.run()
    assert op1_ti.state == State.SUCCESS

    op2_ti = _schedulable_by_key()["op2", 0]
    op2_ti.run()
    assert op2_ti.state == State.SUCCESS

    # This scheduling pass must happen in both cases, before op3 is inspected.
    third_pass = _schedulable_by_key()

    if not condition:
        # op3 is not taken from the schedulable set here; read it from the DAG run instead.
        ti3 = dag_run.get_task_instance("op3", map_index=0, session=session)
    else:
        ti3 = third_pass["op3", 0]
        ti3.run()

    assert ti3.state == final_state


def test_parent_follow_branch(session, dag_maker):
"""
A simple DAG with a BranchPythonOperator that follows op2. NotPreviouslySkippedDep is met.
Expand Down

0 comments on commit ec4db3e

Please sign in to comment.