diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index a27579b05e75..0745155cea9a 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2671,14 +2671,6 @@ def signal_handler(signum, frame): previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session ) - def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None: - """Render named map index if the DAG author defined map_index_template at the task level.""" - if jinja_env is None or (template := context.get("map_index_template")) is None: - return None - rendered_map_index = jinja_env.from_string(template).render(context) - log.debug("Map index rendered as %s", rendered_map_index) - return rendered_map_index - # Execute the task. with set_current_context(context): try: @@ -2686,10 +2678,10 @@ def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) except Exception: # If the task failed, swallow rendering error so it doesn't mask the main error. with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError): - self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env) + self.rendered_map_index = self._render_map_index(context, jinja_env=jinja_env) raise else: # If the task succeeded, render normally to let rendering error bubble up. - self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env) + self.rendered_map_index = self._render_map_index(context, jinja_env=jinja_env) # Run post_execute callback self.task.post_execute(context=context, result=result) @@ -2699,6 +2691,18 @@ def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type}) Stats.incr("ti_successes", tags=self.stats_tags) + def _render_map_index(self, context: Context, *, jinja_env: jinja2.Environment | None) -> str | None: + """Render named map index if the DAG author defined map_index_template at the task level.""" + if TYPE_CHECKING: + assert self.task + + if jinja_env is None or (template := context.get("map_index_template")) is None: + return None + + rendered_map_index = self.task.render_template(template, context, jinja_env) + self.log.debug("Map index rendered as %s", rendered_map_index) + return rendered_map_index + def _execute_task(self, context: Context, task_orig: Operator): """ Execute Task (optionally with a Timeout) and push Xcom results. @@ -2805,11 +2809,23 @@ def dry_run(self) -> None: assert self.task self.task = self.task.prepare_for_execution() - self.render_templates() + context = self.get_template_context(ignore_param_exceptions=False) + + jinja_env = None + dag = self.task.get_dag() + if dag is not None: + jinja_env = dag.get_template_env() + if TYPE_CHECKING: assert isinstance(self.task, BaseOperator) + self.render_templates(context=context, jinja_env=jinja_env) self.task.dry_run() + self.rendered_map_index = self._render_map_index( + context, + jinja_env=jinja_env, + ) + @provide_session def _handle_reschedule( self, diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index fecde3bcb2c2..87373ca2864f 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -3608,6 +3608,73 @@ def raise_skip_exception(): assert State.SKIPPED == ti.state assert callback_function.called + def test_rendered_map_index_on_dry_run(self, dag_maker, session): + from airflow.utils.task_instance_session import set_current_task_instance_session + + with dag_maker(dag_id="example", session=session) as dag: + + @dag.task + def emit(): + return ["a", "b"] + + @dag.task(map_index_template="instance-map-index-{{ task.op_kwargs['value'] * 3 }}") + def example(value: str): + # Whatever lies here won't be executed as we trigger a `dry_run()` only. + assert value + + example.expand(value=emit()) + + dag_run = dag_maker.create_dagrun() + emit_ti = dag_run.get_task_instance("emit", session=session) + emit_ti.refresh_from_task(dag.get_task("emit")) + emit_ti.run() + + example_task = dag.get_task("example") + mapped_tis, _ = example_task.expand_mapped_task(dag_run.run_id, session=session) + + rendered_map_index = [] + with set_current_task_instance_session(session): + for ti in mapped_tis: + ti.task = example_task + ti.refresh_from_task(example_task) + ti.dry_run() + rendered_map_index.append(ti.rendered_map_index) + + assert sorted(rendered_map_index) == ["instance-map-index-aaa", "instance-map-index-bbb"] + + def test_rendered_map_index_example_from_doc(self, dag_maker, session): + from airflow.operators.python import get_current_context + + with dag_maker(dag_id="example", session=session) as dag: + + @dag.task + def emit(): + return ["a", "b"] + + @dag.task(map_index_template="{{ my_variable }}") + def my_task(my_value: str): + context = get_current_context() + context["my_variable"] = my_value * 3 # type: ignore[typeddict-unknown-key] + + my_task.expand(my_value=emit()) + + dag_run = dag_maker.create_dagrun() + emit_ti = dag_run.get_task_instance("emit", session=session) + emit_ti.refresh_from_task(dag.get_task("emit")) + emit_ti.run() + + my_task = dag.get_task("my_task") + mapped_tis, _ = my_task.expand_mapped_task(dag_run.run_id, session=session) + + rendered_map_index = [] + for ti in mapped_tis: + ti.task = my_task + ti.refresh_from_task(my_task) + ti.run() + rendered_map_index.append(ti.rendered_map_index) + + assert sorted(rendered_map_index) == ["aaa", "bbb"] + @pytest.mark.parametrize("pool_override", [None, "test_pool2"]) @pytest.mark.parametrize("queue_by_policy", [None, "forced_queue"]) @@ -4247,7 +4314,8 @@ def show(value): for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): ti.refresh_from_task(show_task) - ti.run() + ti.dry_run() + assert outputs == expected_outputs def test_map_product(self, dag_maker, session):