diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index cd8287be97b73..c3a8be3134ced 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -75,6 +75,7 @@ class TITerminalStatePayload(StrictBaseModel):
 
     end_date: UtcDateTime
     """When the task completed executing"""
+    rendered_map_index: str | None = None
 
 
 class TISuccessStatePayload(StrictBaseModel):
@@ -97,6 +98,7 @@ class TISuccessStatePayload(StrictBaseModel):
 
     task_outlets: Annotated[list[AssetProfile], Field(default_factory=list)]
     outlet_events: Annotated[list[dict[str, Any]], Field(default_factory=list)]
+    rendered_map_index: str | None = None
 
 
 class TITargetStatePayload(StrictBaseModel):
@@ -136,6 +138,7 @@ class TIDeferredStatePayload(StrictBaseModel):
 
     Both forms will be passed along to the TaskSDK upon resume, the server will not handle either.
     """
+    rendered_map_index: str | None = None
 
 
 class TIRescheduleStatePayload(StrictBaseModel):
@@ -171,6 +174,7 @@ class TIRetryStatePayload(StrictBaseModel):
         ),
     ]
     end_date: UtcDateTime
+    rendered_map_index: str | None = None
 
 
 class TISkippedDownstreamTasksStatePayload(StrictBaseModel):
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
index 0d3a225305b67..54329054c1213 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
@@ -19,7 +19,10 @@
 
 from cadwyn import HeadVersion, Version, VersionBundle
 
+from airflow.api_fastapi.execution_api.versions.v2025_04_28 import AddRenderedMapIndexField
+
 bundle = VersionBundle(
     HeadVersion(),
+    Version("2025-04-28", AddRenderedMapIndexField),
     Version("2025-04-11"),
 )
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_04_28.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_04_28.py
new file mode 100644
index 0000000000000..e0916b4c93d67
--- /dev/null
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_04_28.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from cadwyn import VersionChange, schema
+
+from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
+    TIDeferredStatePayload,
+    TIRetryStatePayload,
+    TISuccessStatePayload,
+    TITerminalStatePayload,
+)
+
+
+class AddRenderedMapIndexField(VersionChange):
+    """Add the `rendered_map_index` field to payload models."""
+
+    description = __doc__
+
+    instructions_to_migrate_to_previous_version = (
+        schema(TITerminalStatePayload).field("rendered_map_index").didnt_exist,
+        schema(TISuccessStatePayload).field("rendered_map_index").didnt_exist,
+        schema(TIDeferredStatePayload).field("rendered_map_index").didnt_exist,
+        schema(TIRetryStatePayload).field("rendered_map_index").didnt_exist,
+    )
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py b/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
index adaa93bc202af..32c53ae0db9cb 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
@@ -26,7 +26,7 @@ def test_custom_openapi_includes_extra_schemas(client):
     """Test to ensure that extra schemas are correctly included in the OpenAPI schema."""
-    response = client.get("/execution/openapi.json?version=2025-04-11")
+    response = client.get("/execution/openapi.json?version=2025-04-28")
     assert response.status_code == 200
 
     openapi_schema = response.json()
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 8af11c04df960..41bba2c2bd522 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -44,6 +44,7 @@
 
 DEFAULT_START_DATE = timezone.parse("2024-10-31T11:00:00Z")
 DEFAULT_END_DATE = timezone.parse("2024-10-31T12:00:00Z")
+DEFAULT_RENDERED_MAP_INDEX = "test rendered map index"
 
 
 def _create_asset_aliases(session, num: int = 2) -> None:
@@ -465,6 +466,39 @@ def test_ti_update_state_to_terminal(
         assert ti.state == expected_state
         assert ti.end_date == end_date
 
+    @pytest.mark.parametrize(
+        ("state", "end_date", "expected_state", "rendered_map_index"),
+        [
+            (State.SUCCESS, DEFAULT_END_DATE, State.SUCCESS, DEFAULT_RENDERED_MAP_INDEX),
+            (State.FAILED, DEFAULT_END_DATE, State.FAILED, DEFAULT_RENDERED_MAP_INDEX),
+            (State.SKIPPED, DEFAULT_END_DATE, State.SKIPPED, DEFAULT_RENDERED_MAP_INDEX),
+        ],
+    )
+    def test_ti_update_state_to_terminal_with_rendered_map_index(
+        self, client, session, create_task_instance, state, end_date, expected_state, rendered_map_index
+    ):
+        ti = create_task_instance(
+            task_id="test_ti_update_state_to_terminal_with_rendered_map_index",
+            start_date=DEFAULT_START_DATE,
+            state=State.RUNNING,
+        )
+        session.commit()
+
+        response = client.patch(
+            f"/execution/task-instances/{ti.id}/state",
+            json={"state": state, "end_date": end_date.isoformat(), "rendered_map_index": rendered_map_index},
+        )
+
+        assert response.status_code == 204
+        assert response.text == ""
+
+        session.expire_all()
+
+        ti = session.get(TaskInstance, ti.id)
+        assert ti.state == expected_state
+        assert ti.end_date == end_date
+        assert ti.rendered_map_index == rendered_map_index
+
     @pytest.mark.parametrize(
         "task_outlets",
         [
diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py
index 46e4fe6907a73..5f76e8360be7d 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -146,22 +146,29 @@ def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext:
         resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json())
         return TIRunContext.model_validate_json(resp.read())
 
-    def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: datetime):
+    def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: datetime, rendered_map_index):
         """Tell the API server that this TI has reached a terminal state."""
         if state == TaskInstanceState.SUCCESS:
             raise ValueError("Logic error. SUCCESS state should call the `succeed` function instead")
         # TODO: handle the naming better. finish sounds wrong as "even" deferred is essentially finishing.
-        body = TITerminalStatePayload(end_date=when, state=TerminalStateNonSuccess(state))
+        body = TITerminalStatePayload(
+            end_date=when, state=TerminalStateNonSuccess(state), rendered_map_index=rendered_map_index
+        )
         self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())
 
-    def retry(self, id: uuid.UUID, end_date: datetime):
+    def retry(self, id: uuid.UUID, end_date: datetime, rendered_map_index):
         """Tell the API server that this TI has failed and reached a up_for_retry state."""
-        body = TIRetryStatePayload(end_date=end_date)
+        body = TIRetryStatePayload(end_date=end_date, rendered_map_index=rendered_map_index)
         self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())
 
-    def succeed(self, id: uuid.UUID, when: datetime, task_outlets, outlet_events):
+    def succeed(self, id: uuid.UUID, when: datetime, task_outlets, outlet_events, rendered_map_index):
         """Tell the API server that this TI has succeeded."""
-        body = TISuccessStatePayload(end_date=when, task_outlets=task_outlets, outlet_events=outlet_events)
+        body = TISuccessStatePayload(
+            end_date=when,
+            task_outlets=task_outlets,
+            outlet_events=outlet_events,
+            rendered_map_index=rendered_map_index,
+        )
         self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())
 
     def heartbeat(self, id: uuid.UUID, pid: int):
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index 7c7647635ea05..a477c3ffc410f 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -27,7 +27,7 @@
 
 from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue
 
-API_VERSION: Final[str] = "2025-04-11"
+API_VERSION: Final[str] = "2025-04-28"
 
 
 class AssetAliasReferenceAssetEventDagRun(BaseModel):
@@ -193,6 +193,7 @@ class TIDeferredStatePayload(BaseModel):
     trigger_timeout: Annotated[timedelta | None, Field(title="Trigger Timeout")] = None
     next_method: Annotated[str, Field(title="Next Method")]
    next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next Kwargs")] = None
+    rendered_map_index: Annotated[str | None, Field(title="Rendered Map Index")] = None
 
 
 class TIEnterRunningPayload(BaseModel):
@@ -245,6 +246,7 @@ class TIRetryStatePayload(BaseModel):
     )
     state: Annotated[Literal["up_for_retry"] | None, Field(title="State")] = "up_for_retry"
     end_date: Annotated[AwareDatetime, Field(title="End Date")]
+    rendered_map_index: Annotated[str | None, Field(title="Rendered Map Index")] = None
 
 
 class TISkippedDownstreamTasksStatePayload(BaseModel):
@@ -270,6 +272,7 @@ class TISuccessStatePayload(BaseModel):
     end_date: Annotated[AwareDatetime, Field(title="End Date")]
     task_outlets: Annotated[list[AssetProfile] | None, Field(title="Task Outlets")] = None
     outlet_events: Annotated[list[dict[str, Any]] | None, Field(title="Outlet Events")] = None
+    rendered_map_index: Annotated[str | None, Field(title="Rendered Map Index")] = None
 
 
 class TITargetStatePayload(BaseModel):
@@ -494,3 +497,4 @@ class TITerminalStatePayload(BaseModel):
     )
     state: TerminalStateNonSuccess
     end_date: Annotated[AwareDatetime, Field(title="End Date")]
+    rendered_map_index: Annotated[str | None, Field(title="Rendered Map Index")] = None
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index 699fe5c0370bf..a25ba5745827d 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -376,6 +376,7 @@ class TaskState(BaseModel):
     ]
     end_date: datetime | None = None
     type: Literal["TaskState"] = "TaskState"
+    rendered_map_index: str | None = None
 
 
 class SucceedTask(TISuccessStatePayload):
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 001f8deb701d4..b5cf977488b71 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -767,6 +767,7 @@ class ActivitySubprocess(WatchedSubprocess):
     # TODO: This should come from airflow.cfg: [core] task_success_overtime
     TASK_OVERTIME_THRESHOLD: ClassVar[float] = 20.0
     _task_end_time_monotonic: float | None = attrs.field(default=None, init=False)
+    _rendered_map_index: str | None = attrs.field(default=None, init=False)
 
     decoder: ClassVar[TypeAdapter[ToSupervisor]] = TypeAdapter(ToSupervisor)
 
@@ -842,7 +843,10 @@ def wait(self) -> int:
         # by the subprocess in the `handle_requests` method.
         if self.final_state not in STATES_SENT_DIRECTLY:
             self.client.task_instances.finish(
-                id=self.id, state=self.final_state, when=datetime.now(tz=timezone.utc)
+                id=self.id,
+                state=self.final_state,
+                when=datetime.now(tz=timezone.utc),
+                rendered_map_index=self._rendered_map_index,
             )
 
         # Now at the last possible moment, when all logs and comms with the subprocess has finished, lets
@@ -988,21 +992,26 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
         if isinstance(msg, TaskState):
             self._terminal_state = msg.state
             self._task_end_time_monotonic = time.monotonic()
+            self._rendered_map_index = msg.rendered_map_index
         elif isinstance(msg, SucceedTask):
             self._terminal_state = msg.state
             self._task_end_time_monotonic = time.monotonic()
+            self._rendered_map_index = msg.rendered_map_index
             self.client.task_instances.succeed(
                 id=self.id,
                 when=msg.end_date,
                 task_outlets=msg.task_outlets,
                 outlet_events=msg.outlet_events,
+                rendered_map_index=self._rendered_map_index,
             )
         elif isinstance(msg, RetryTask):
             self._terminal_state = msg.state
             self._task_end_time_monotonic = time.monotonic()
+            self._rendered_map_index = msg.rendered_map_index
             self.client.task_instances.retry(
                 id=self.id,
                 end_date=msg.end_date,
+                rendered_map_index=self._rendered_map_index,
             )
         elif isinstance(msg, GetConnection):
             conn = self.client.connections.get(msg.conn_id)
@@ -1045,6 +1054,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
             resp = xcom
         elif isinstance(msg, DeferTask):
             self._terminal_state = TaskInstanceState.DEFERRED
+            self._rendered_map_index = msg.rendered_map_index
             self.client.task_instances.defer(self.id, msg)
         elif isinstance(msg, RescheduleTask):
             self._terminal_state = TaskInstanceState.UP_FOR_RESCHEDULE
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 6e1e5f885f5e1..9092ee86f0b15 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -19,6 +19,7 @@
 
 from __future__ import annotations
 
+import contextlib
 import contextvars
 import functools
 import os
@@ -134,6 +135,8 @@ class RuntimeTaskInstance(TaskInstance):
     is_mapped: bool | None = None
     """True if the original task was mapped."""
 
+    rendered_map_index: str | None = None
+
     def __rich_repr__(self):
         yield "id", self.id
         yield "task_id", self.task_id
@@ -831,7 +834,17 @@ def run(
             ti.state = state = TaskInstanceState.FAILED
             return state, msg, error
 
-        result = _execute_task(context, ti, log)
+        try:
+            result = _execute_task(context, ti, log)
+        except Exception:
+            import jinja2
+
+            # If the task failed, swallow rendering error so it doesn't mask the main error.
+            with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError):
+                ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
+            raise
+        else:  # If the task succeeded, render normally to let rendering error bubble up.
+            ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
 
         _push_xcom_if_needed(result, ti, log)
 
@@ -851,6 +864,7 @@ def run(
         msg = TaskState(
             state=TaskInstanceState.SKIPPED,
             end_date=datetime.now(tz=timezone.utc),
+            rendered_map_index=ti.rendered_map_index,
         )
         state = TaskInstanceState.SKIPPED
     except AirflowRescheduleException as reschedule:
@@ -868,6 +882,7 @@ def run(
         msg = TaskState(
             state=TaskInstanceState.FAILED,
             end_date=datetime.now(tz=timezone.utc),
+            rendered_map_index=ti.rendered_map_index,
        )
         state = TaskInstanceState.FAILED
         error = e
@@ -884,6 +899,7 @@ def run(
         msg = TaskState(
             state=TaskInstanceState.FAILED,
             end_date=datetime.now(tz=timezone.utc),
+            rendered_map_index=ti.rendered_map_index,
         )
         state = TaskInstanceState.FAILED
         error = e
@@ -915,6 +931,7 @@ def _handle_current_task_success(
         end_date=datetime.now(tz=timezone.utc),
         task_outlets=task_outlets,
         outlet_events=outlet_events,
+        rendered_map_index=ti.rendered_map_index,
     )
     return msg, TaskInstanceState.SUCCESS
 
@@ -925,7 +942,9 @@ def _handle_current_task_failed(
     end_date = datetime.now(tz=timezone.utc)
     if ti._ti_context_from_server and ti._ti_context_from_server.should_retry:
         return RetryTask(end_date=end_date), TaskInstanceState.UP_FOR_RETRY
-    return TaskState(state=TaskInstanceState.FAILED, end_date=end_date), TaskInstanceState.FAILED
+    return TaskState(
+        state=TaskInstanceState.FAILED, end_date=end_date, rendered_map_index=ti.rendered_map_index
+    ), TaskInstanceState.FAILED
 
 
 def _handle_trigger_dag_run(
@@ -951,11 +970,19 @@ def _handle_trigger_dag_run(
                 "Dag Run already exists, skipping task as skip_when_already_exists is set to True.",
                 dag_id=drte.trigger_dag_id,
             )
-            msg = TaskState(state=TaskInstanceState.SKIPPED, end_date=datetime.now(tz=timezone.utc))
+            msg = TaskState(
+                state=TaskInstanceState.SKIPPED,
+                end_date=datetime.now(tz=timezone.utc),
+                rendered_map_index=ti.rendered_map_index,
+            )
             state = TaskInstanceState.SKIPPED
         else:
             log.error("Dag Run already exists, marking task as failed.", dag_id=drte.trigger_dag_id)
-            msg = TaskState(state=TaskInstanceState.FAILED, end_date=datetime.now(tz=timezone.utc))
+            msg = TaskState(
+                state=TaskInstanceState.FAILED,
+                end_date=datetime.now(tz=timezone.utc),
+                rendered_map_index=ti.rendered_map_index,
+            )
             state = TaskInstanceState.FAILED
 
         return msg, state
@@ -1001,7 +1028,11 @@ def _handle_trigger_dag_run(
         log.error(
             "DagRun finished with failed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
         )
-        msg = TaskState(state=TaskInstanceState.FAILED, end_date=datetime.now(tz=timezone.utc))
+        msg = TaskState(
+            state=TaskInstanceState.FAILED,
+            end_date=datetime.now(tz=timezone.utc),
+            rendered_map_index=ti.rendered_map_index,
+        )
         state = TaskInstanceState.FAILED
         return msg, state
     if comms_msg.state in drte.allowed_states:
@@ -1106,6 +1137,16 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):
     return result
 
 
+def _render_map_index(context: Context, ti: RuntimeTaskInstance, log: Logger) -> str | None:
+    """Render named map index if the DAG author defined map_index_template at the task level."""
+    if (template := context.get("map_index_template")) is None:
+        return None
+    jinja_env = ti.task.dag.get_template_env()
+    rendered_map_index = jinja_env.from_string(template).render(context)
+    log.info("Map index rendered as %s", rendered_map_index)
+    return rendered_map_index
+
+
 def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger):
     """Push XCom values when task has ``do_xcom_push`` set to ``True`` and the task returns a result."""
     if ti.task.do_xcom_push:
diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py
index e1f678bb1f77f..ce5329197493e 100644
--- a/task-sdk/tests/task_sdk/api/test_client.py
+++ b/task-sdk/tests/task_sdk/api/test_client.py
@@ -295,13 +295,16 @@ def handle_request(request: httpx.Request) -> httpx.Response:
                 actual_body = json.loads(request.read())
                 assert actual_body["end_date"] == "2024-10-31T12:00:00Z"
                 assert actual_body["state"] == state
+                assert actual_body["rendered_map_index"] == "test"
                 return httpx.Response(
                     status_code=204,
                 )
             return httpx.Response(status_code=400, json={"detail": "Bad Request"})
 
         client = make_client(transport=httpx.MockTransport(handle_request))
-        client.task_instances.finish(ti_id, state=state, when="2024-10-31T12:00:00Z")
+        client.task_instances.finish(
+            ti_id, state=state, when="2024-10-31T12:00:00Z", rendered_map_index="test"
+        )
 
     def test_task_instance_heartbeat(self):
         # Simulate a successful response from the server that sends a heartbeat for a ti
@@ -383,13 +386,16 @@ def handle_request(request: httpx.Request) -> httpx.Response:
                 actual_body = json.loads(request.read())
                 assert actual_body["state"] == "up_for_retry"
                 assert actual_body["end_date"] == "2024-10-31T12:00:00Z"
+                assert actual_body["rendered_map_index"] == "test"
                 return httpx.Response(
                     status_code=204,
                 )
             return httpx.Response(status_code=400, json={"detail": "Bad Request"})
 
         client = make_client(transport=httpx.MockTransport(handle_request))
-        client.task_instances.retry(ti_id, end_date=timezone.parse("2024-10-31T12:00:00Z"))
+        client.task_instances.retry(
+            ti_id, end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test"
+        )
 
     @pytest.mark.parametrize(
         "rendered_fields",
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 0795e1155c985..5690f9b418ec6 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -1193,11 +1193,17 @@ def watched_subprocess(self, mocker):
             id="patch_task_instance_to_skipped",
         ),
         pytest.param(
-            RetryTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
+            RetryTask(
+                end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test retry task"
+            ),
             b"",
             "task_instances.retry",
             (),
-            {"id": TI_ID, "end_date": timezone.parse("2024-10-31T12:00:00Z")},
+            {
+                "id": TI_ID,
+                "end_date": timezone.parse("2024-10-31T12:00:00Z"),
+                "rendered_map_index": "test retry task",
+            },
             "",
             id="up_for_retry",
         ),
@@ -1317,7 +1323,9 @@ def watched_subprocess(self, mocker):
             id="get_asset_events_by_asset_alias",
         ),
         pytest.param(
-            SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
+            SucceedTask(
+                end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test success task"
+            ),
             b"",
             "task_instances.succeed",
             (),
@@ -1326,6 +1334,7 @@ def watched_subprocess(self, mocker):
                 "outlet_events": None,
                 "task_outlets": None,
                 "when": timezone.parse("2024-10-31T12:00:00Z"),
+                "rendered_map_index": "test success task",
             },
             "",
             id="succeed_task",
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index f9ad46f357e7c..8df65b88ebf48 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -970,6 +970,37 @@ def test_function():
     assert os.environ["AIRFLOW_CTX_TASK_ID"] == "test_env_task"
 
 
+def test_execute_success_task_with_rendered_map_index(create_runtime_ti, mock_supervisor_comms):
+    """Test that the map index is rendered in the task context."""
+
+    def test_function():
+        return "test function"
+
+    task = PythonOperator(
+        task_id="test_task",
+        python_callable=test_function,
+        map_index_template="Hello! {{ run_id }}",
+    )
+
+    ti = create_runtime_ti(task=task, dag_id="dag_with_map_index_template")
+
+    run(ti, ti.get_template_context(), log=mock.MagicMock())
+
+    assert ti.rendered_map_index == "Hello! test_run"
+
+
+def test_execute_failed_task_with_rendered_map_index(create_runtime_ti, mock_supervisor_comms):
+    """Test that the map index is rendered in the task context."""
+
+    task = BaseOperator(task_id="test_task", map_index_template="Hello! {{ run_id }}")
+
+    ti = create_runtime_ti(task=task, dag_id="dag_with_map_index_template")
+
+    run(ti, ti.get_template_context(), log=mock.MagicMock())
+
+    assert ti.rendered_map_index == "Hello! test_run"
+
+
 class TestRuntimeTaskInstance:
     def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context):
         """Test get_template_context without ti_context_from_server."""
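
For reference, a minimal sketch of the DAG-author side of the behaviour this patch wires through: when `map_index_template` is set on a task, the Task SDK renders it against the task context after the task finishes (via `_render_map_index`) and the supervisor reports the result to the execution API as `rendered_map_index`. The dag_id below is made up, the template mirrors the one used in the tests above, and the import paths assume Airflow 3's Task SDK and standard provider.

    from airflow.providers.standard.operators.python import PythonOperator
    from airflow.sdk import DAG

    with DAG(dag_id="example_rendered_map_index"):  # hypothetical dag_id
        PythonOperator(
            task_id="greet",
            python_callable=lambda: "done",
            # Rendered from the task context once the task reaches a terminal state and
            # sent back to the API server in the new `rendered_map_index` payload field.
            map_index_template="Hello! {{ run_id }}",
        )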