diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index dbe63b4e9c962..d8a1bcef358c3 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -50,6 +50,7 @@
 import jinja2
 import pendulum
 from dateutil.relativedelta import relativedelta
+from sqlalchemy import or_
 from sqlalchemy.orm import Session
 from sqlalchemy.orm.exc import NoResultFound
 
@@ -75,6 +76,7 @@
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.operator_resources import Resources
 from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.state import State, TaskInstanceState
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.weight_rule import WeightRule
 
@@ -1800,6 +1802,85 @@ def wait_for_downstream(self) -> bool:
     def depends_on_past(self) -> bool:
         return self.partial_kwargs.get("depends_on_past") or self.wait_for_downstream
 
+    @provide_session
+    def expand_mapped_task(self, upstream_ti: "TaskInstance", session: "Session" = NEW_SESSION) -> None:
+        """Create the mapped TaskInstances for this mapped task.
+
+        The number of instances is the length of the ``TaskMap`` recorded for
+        ``upstream_ti``. If an unfinished TaskInstance with ``map_index == -1``
+        exists for this task, it is either marked SKIPPED (empty map) or
+        re-used as index 0. TaskInstances whose ``map_index`` is not below the
+        map length are marked REMOVED.
+
+        :param upstream_ti: TaskInstance whose ``TaskMap`` row gives the map length.
+        :param session: ORM session; ``@provide_session`` supplies one if omitted.
+        :raises RuntimeError: If no ``TaskMap`` length is found for ``upstream_ti``.
+        """
+        # TODO: support having multiple mapped upstreams?
+        from airflow.models.taskmap import TaskMap
+        from airflow.settings import task_instance_mutation_hook
+
+        task_map_info_length: Optional[int] = (
+            session.query(TaskMap.length)
+            .filter_by(
+                dag_id=upstream_ti.dag_id,
+                task_id=upstream_ti.task_id,
+                run_id=upstream_ti.run_id,
+                map_index=upstream_ti.map_index,
+            )
+            .scalar()
+        )
+        if task_map_info_length is None:
+            # TODO: What would lead to this? How can this be better handled?
+            raise RuntimeError("mapped operator cannot be expanded; upstream not found")
+
+        unmapped_ti: Optional[TaskInstance] = (
+            session.query(TaskInstance)
+            .filter(
+                TaskInstance.dag_id == upstream_ti.dag_id,
+                TaskInstance.run_id == upstream_ti.run_id,
+                TaskInstance.task_id == self.task_id,
+                TaskInstance.map_index == -1,
+                or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
+            )
+            .one_or_none()
+        )
+
+        if unmapped_ti:
+            # The unmapped task instance still exists and is unfinished, i.e. we
+            # haven't tried to run it before.
+            if task_map_info_length < 1:
+                # If the upstream maps this to a zero-length value, simply mark the
+                # unmapped task instance as SKIPPED (if needed).
+                self.log.info("Marking %s as SKIPPED since the map has 0 values to expand", unmapped_ti)
+                unmapped_ti.state = TaskInstanceState.SKIPPED
+                session.flush()
+                return
+            # Otherwise convert this into the first mapped index, and create
+            # TaskInstance for other indexes.
+            unmapped_ti.map_index = 0
+            indexes_to_map = range(1, task_map_info_length)
+        else:
+            indexes_to_map = range(task_map_info_length)
+
+        for index in indexes_to_map:
+            # TODO: Make more efficient with bulk_insert_mappings/bulk_save_objects.
+            # TODO: Change `TaskInstance` ctor to take Operator, not BaseOperator
+            ti = TaskInstance(self, run_id=upstream_ti.run_id, map_index=index)  # type: ignore
+            task_instance_mutation_hook(ti)
+            session.merge(ti)
+
+        # Set to "REMOVED" any (old) TaskInstances with map indices greater
+        # than the current map value
+        session.query(TaskInstance).filter(
+            TaskInstance.dag_id == upstream_ti.dag_id,
+            TaskInstance.task_id == self.task_id,
+            TaskInstance.run_id == upstream_ti.run_id,
+            TaskInstance.map_index >= task_map_info_length,
+        ).update({TaskInstance.state: TaskInstanceState.REMOVED})
+
+        session.flush()
+
 
 # TODO: Deprecate for Airflow 3.0
 Chainable = Union[DependencyMixin, Sequence[DependencyMixin]]
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 9f47d38c3bed0..db1f71947d804 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -813,7 +813,7 @@ def task_filter(task: "BaseOperator"):
 
             def create_ti_mapping(task: "BaseOperator"):
                 created_counts[task.task_type] += 1
-                return TI.insert_mapping(self.run_id, task)
+                return TI.insert_mapping(self.run_id, task, map_index=-1)
 
         else:
 
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 36cda42bf0de1..0189ebe85f6c2 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -487,7 +487,7 @@ def __init__(
         self.test_mode = False
 
     @staticmethod
-    def insert_mapping(run_id: str, task: "BaseOperator") -> dict:
+    def insert_mapping(run_id: str, task: "BaseOperator", map_index: int) -> dict:
         """:meta private:"""
         return {
             'dag_id': task.dag_id,
@@ -504,6 +504,7 @@ def insert_mapping(run_id: str, task: "BaseOperator") -> dict:
             'max_tries': task.retries,
             'executor_config': task.executor_config,
             'operator': task.task_type,
+            'map_index': map_index,
         }
 
     @reconstructor
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index ee79aa9a75bdb..a78ae4d9f5586 100644
--- a/tests/models/test_baseoperator.py
+++ 
b/tests/models/test_baseoperator.py
@@ -36,8 +36,12 @@
     chain,
     cross_downstream,
 )
+from airflow.models.taskinstance import TaskInstance
+from airflow.models.taskmap import TaskMap
+from airflow.models.xcom_arg import XComArg
 from airflow.utils.context import Context
 from airflow.utils.edgemodifier import Label
+from airflow.utils.state import TaskInstanceState
 from airflow.utils.task_group import TaskGroup
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.weight_rule import WeightRule
@@ -733,8 +737,6 @@ def test_map_unknown_arg_raises():
 
 def test_map_xcom_arg():
     """Test that dependencies are correct when mapping with an XComArg"""
-    from airflow.models.xcom_arg import XComArg
-
     with DAG("test-dag", start_date=DEFAULT_DATE):
         task1 = BaseOperator(task_id="op1")
         xcomarg = XComArg(task1, "test_key")
@@ -767,3 +769,86 @@ def test_partial_on_class_invalid_ctor_args() -> None:
     """
     with pytest.raises(TypeError, match=r"arguments 'foo', 'bar'"):
         MockOperator.partial(task_id='a', foo='bar', bar=2)
+
+
+@pytest.mark.parametrize(
+    ["num_existing_tis", "expected"],
+    (
+        pytest.param(0, [(0, None), (1, None), (2, None)], id='only-unmapped-ti-exists'),
+        pytest.param(3, [(0, None), (1, None), (2, None)], id='all-tis-exist'),
+        pytest.param(
+            5,
+            [(0, None), (1, None), (2, None), (3, TaskInstanceState.REMOVED), (4, TaskInstanceState.REMOVED)],
+            id="tis-to-be-remove",
+        ),
+    ),
+)
+def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expected):
+    """Expanding over a 3-item map yields TIs 0..2 and marks any higher index REMOVED."""
+    literal = [1, 2, {'a': 'b'}]
+    with dag_maker(session=session):
+        task1 = BaseOperator(task_id="op1")
+        xcomarg = XComArg(task1, "test_key")
+        mapped = MockOperator(task_id='task_2').map(arg2=xcomarg)
+
+    dr = dag_maker.create_dagrun()
+
+    session.add(
+        TaskMap(
+            dag_id=dr.dag_id,
+            task_id=task1.task_id,
+            run_id=dr.run_id,
+            map_index=-1,
+            length=len(literal),
+            keys=None,
+        )
+    )
+
+    if num_existing_tis:
+        # Delete every TI of the mapped task (including map_index=-1) so that
+        # only the explicitly indexed TIs created below exist.
+        session.query(TaskInstance).filter(
+            TaskInstance.dag_id == mapped.dag_id,
+            TaskInstance.task_id == mapped.task_id,
+            TaskInstance.run_id == dr.run_id,
+        ).delete()
+
+    for index in range(num_existing_tis):
+        ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index)  # type: ignore
+        session.add(ti)
+    session.flush()
+
+    mapped.expand_mapped_task(upstream_ti=dr.get_task_instance(task1.task_id), session=session)
+
+    indices = (
+        session.query(TaskInstance.map_index, TaskInstance.state)
+        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
+        .order_by(TaskInstance.map_index)
+        .all()
+    )
+
+    assert indices == expected
+
+
+def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session):
+    """A zero-length map leaves the unmapped (map_index=-1) TI in place, marked SKIPPED."""
+    with dag_maker(session=session):
+        task1 = BaseOperator(task_id="op1")
+        xcomarg = XComArg(task1, "test_key")
+        mapped = MockOperator(task_id='task_2').map(arg2=xcomarg)
+
+    dr = dag_maker.create_dagrun()
+
+    session.add(
+        TaskMap(dag_id=dr.dag_id, task_id=task1.task_id, run_id=dr.run_id, map_index=-1, length=0, keys=None)
+    )
+
+    mapped.expand_mapped_task(upstream_ti=dr.get_task_instance(task1.task_id), session=session)
+
+    indices = (
+        session.query(TaskInstance.map_index, TaskInstance.state)
+        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
+        .order_by(TaskInstance.map_index)
+        .all()
+    )
+
+    assert indices == [(-1, TaskInstanceState.SKIPPED)]
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 56f214009f111..c5f048554b0dc 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -39,6 +39,7 @@
 from airflow.utils.types import DagRunType
 from tests.models import DEFAULT_DATE
 from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs
+from tests.test_utils.mock_operators import MockOperator
 
 
 class TestDagRun(unittest.TestCase):
@@ -874,3 +875,22 @@ def test_verify_integrity_task_start_date(Stats_incr, session, run_type, expecte
     assert len(tis) == expected_tis
 
     Stats_incr.assert_called_with('task_instance_created-DummyOperator', expected_tis)
+
+
+@pytest.mark.xfail(reason="TODO: Expand mapped literals at verify_integrity time!")
+def test_expand_mapped_task_instance(dag_maker, session):
+    """Mapping over a 3-item literal should give TIs with map_index 0, 1 and 2 (not implemented yet)."""
+    literal = [1, 2, {'a': 'b'}]
+    with dag_maker(session=session):
+        mapped = MockOperator(task_id='task_2').map(arg2=literal)
+
+    dr = dag_maker.create_dagrun()
+    indices = (
+        session.query(TI.map_index)
+        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
+        .order_by(TI.map_index)
+        .all()
+    )
+
+    # A single-column query still returns row tuples, so compare against 1-tuples.
+    assert indices == [(0,), (1,), (2,)]