From 0debe662bc5417b53ceefe1e06d83ff7269e0453 Mon Sep 17 00:00:00 2001 From: Purna Chander Date: Sat, 26 Apr 2025 00:22:19 +0530 Subject: [PATCH 1/5] rendered map index on task state change #49224 --- .../execution_api/datamodels/taskinstance.py | 4 +++ .../example_params_ui_tutorial.py | 0 .../src/airflow/jobs/triggerer_job_runner.py | 0 .../src/airflow/provider_info.schema.json | 0 .../ui/openapi-gen/queries/prefetch.ts | 0 airflow-core/src/airflow/ui/package.json | 0 .../components/DagActions/RunBackfillForm.tsx | 0 .../test_docker_compose_quick_start.py | 0 task-sdk/src/airflow/sdk/api/client.py | 14 ++++---- .../airflow/sdk/api/datamodels/_generated.py | 5 +++ .../src/airflow/sdk/execution_time/comms.py | 1 + .../airflow/sdk/execution_time/supervisor.py | 11 ++++++- .../airflow/sdk/execution_time/task_runner.py | 32 +++++++++++++++++-- 13 files changed, 57 insertions(+), 10 deletions(-) mode change 100644 => 100755 airflow-core/src/airflow/example_dags/example_params_ui_tutorial.py mode change 100644 => 100755 airflow-core/src/airflow/jobs/triggerer_job_runner.py mode change 100644 => 100755 airflow-core/src/airflow/provider_info.schema.json mode change 100644 => 100755 airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts mode change 100644 => 100755 airflow-core/src/airflow/ui/package.json mode change 100644 => 100755 airflow-core/src/airflow/ui/src/components/DagActions/RunBackfillForm.tsx mode change 100644 => 100755 docker-tests/tests/docker_tests/test_docker_compose_quick_start.py diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index cd8287be97b73..6fd65b185e00d 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -75,6 +75,7 @@ class TITerminalStatePayload(StrictBaseModel): end_date: UtcDateTime """When the task completed executing""" + rendered_map_index: Annotated[str | None, Field(default_factory=str)] class TISuccessStatePayload(StrictBaseModel): @@ -97,6 +98,7 @@ class TISuccessStatePayload(StrictBaseModel): task_outlets: Annotated[list[AssetProfile], Field(default_factory=list)] outlet_events: Annotated[list[dict[str, Any]], Field(default_factory=list)] + rendered_map_index: Annotated[str | None, Field(default_factory=str)] class TITargetStatePayload(StrictBaseModel): @@ -136,6 +138,7 @@ class TIDeferredStatePayload(StrictBaseModel): Both forms will be passed along to the TaskSDK upon resume, the server will not handle either. """ + rendered_map_index: Annotated[str | None, Field(default_factory=str)] class TIRescheduleStatePayload(StrictBaseModel): @@ -171,6 +174,7 @@ class TIRetryStatePayload(StrictBaseModel): ), ] end_date: UtcDateTime + rendered_map_index: Annotated[str | None, Field(default_factory=str)] class TISkippedDownstreamTasksStatePayload(StrictBaseModel): diff --git a/airflow-core/src/airflow/example_dags/example_params_ui_tutorial.py b/airflow-core/src/airflow/example_dags/example_params_ui_tutorial.py old mode 100644 new mode 100755 diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py old mode 100644 new mode 100755 diff --git a/airflow-core/src/airflow/provider_info.schema.json b/airflow-core/src/airflow/provider_info.schema.json old mode 100644 new mode 100755 diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts old mode 100644 new mode 100755 diff --git a/airflow-core/src/airflow/ui/package.json b/airflow-core/src/airflow/ui/package.json old mode 100644 new mode 100755 diff --git a/airflow-core/src/airflow/ui/src/components/DagActions/RunBackfillForm.tsx b/airflow-core/src/airflow/ui/src/components/DagActions/RunBackfillForm.tsx old mode 100644 new mode 100755 diff --git a/docker-tests/tests/docker_tests/test_docker_compose_quick_start.py b/docker-tests/tests/docker_tests/test_docker_compose_quick_start.py old mode 100644 new mode 100755 diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 46e4fe6907a73..99262d6476ebb 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -146,22 +146,24 @@ def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext: resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json()) return TIRunContext.model_validate_json(resp.read()) - def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: datetime): + def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: datetime, rendered_map_index): """Tell the API server that this TI has reached a terminal state.""" if state == TaskInstanceState.SUCCESS: raise ValueError("Logic error. SUCCESS state should call the `succeed` function instead") # TODO: handle the naming better. finish sounds wrong as "even" deferred is essentially finishing. - body = TITerminalStatePayload(end_date=when, state=TerminalStateNonSuccess(state)) + body = TITerminalStatePayload(end_date=when, state=TerminalStateNonSuccess(state), + rendered_map_index=rendered_map_index) self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json()) - def retry(self, id: uuid.UUID, end_date: datetime): + def retry(self, id: uuid.UUID, end_date: datetime, rendered_map_index): """Tell the API server that this TI has failed and reached a up_for_retry state.""" - body = TIRetryStatePayload(end_date=end_date) + body = TIRetryStatePayload(end_date=end_date, rendered_map_index=rendered_map_index) self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json()) - def succeed(self, id: uuid.UUID, when: datetime, task_outlets, outlet_events): + def succeed(self, id: uuid.UUID, when: datetime, task_outlets, outlet_events, rendered_map_index): """Tell the API server that this TI has succeeded.""" - body = TISuccessStatePayload(end_date=when, task_outlets=task_outlets, outlet_events=outlet_events) + body = TISuccessStatePayload(end_date=when, task_outlets=task_outlets, outlet_events=outlet_events, + rendered_map_index=rendered_map_index) self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json()) def heartbeat(self, id: uuid.UUID, pid: int): diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index 7c7647635ea05..d96be1d7d8f60 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -193,6 +193,7 @@ class TIDeferredStatePayload(BaseModel): trigger_timeout: Annotated[timedelta | None, Field(title="Trigger Timeout")] = None next_method: Annotated[str, Field(title="Next Method")] next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next Kwargs")] = None + rendered_map_index: Annotated[str | None, Field(title="Rendered Map Index")] = None class TIEnterRunningPayload(BaseModel): @@ -233,6 +234,7 @@ class TIRescheduleStatePayload(BaseModel): state: Annotated[Literal["up_for_reschedule"] | None, Field(title="State")] = "up_for_reschedule" reschedule_date: Annotated[AwareDatetime, Field(title="Reschedule Date")] end_date: Annotated[AwareDatetime, Field(title="End Date")] + rendered_map_index: Annotated[str | None, Field(title="Rendered Map Index")] = None class TIRetryStatePayload(BaseModel): @@ -245,6 +247,7 @@ class TIRetryStatePayload(BaseModel): ) state: Annotated[Literal["up_for_retry"] | None, Field(title="State")] = "up_for_retry" end_date: Annotated[AwareDatetime, Field(title="End Date")] + rendered_map_index: Annotated[str | None, Field(title="Rendered Map Index")] = None class TISkippedDownstreamTasksStatePayload(BaseModel): @@ -270,6 +273,7 @@ class TISuccessStatePayload(BaseModel): end_date: Annotated[AwareDatetime, Field(title="End Date")] task_outlets: Annotated[list[AssetProfile] | None, Field(title="Task Outlets")] = None outlet_events: Annotated[list[dict[str, Any]] | None, Field(title="Outlet Events")] = None + rendered_map_index: Annotated[str | None, Field(title="Rendered Map Index")] = None class TITargetStatePayload(BaseModel): @@ -494,3 +498,4 @@ class TITerminalStatePayload(BaseModel): ) state: TerminalStateNonSuccess end_date: Annotated[AwareDatetime, Field(title="End Date")] + rendered_map_index: Annotated[str | None, Field(title="Rendered Map Index")] = None diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 699fe5c0370bf..a25ba5745827d 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -376,6 +376,7 @@ class TaskState(BaseModel): ] end_date: datetime | None = None type: Literal["TaskState"] = "TaskState" + rendered_map_index: str | None = None class SucceedTask(TISuccessStatePayload): diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 001f8deb701d4..29064a157a3f1 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -767,6 +767,7 @@ class ActivitySubprocess(WatchedSubprocess): # TODO: This should come from airflow.cfg: [core] task_success_overtime TASK_OVERTIME_THRESHOLD: ClassVar[float] = 20.0 _task_end_time_monotonic: float | None = attrs.field(default=None, init=False) + _rendered_map_index: str | None = attrs.field(default=None, init=False) decoder: ClassVar[TypeAdapter[ToSupervisor]] = TypeAdapter(ToSupervisor) @@ -842,7 +843,8 @@ def wait(self) -> int: # by the subprocess in the `handle_requests` method. if self.final_state not in STATES_SENT_DIRECTLY: self.client.task_instances.finish( - id=self.id, state=self.final_state, when=datetime.now(tz=timezone.utc) + id=self.id, state=self.final_state, when=datetime.now(tz=timezone.utc), + rendered_map_index=self._rendered_map_index ) # Now at the last possible moment, when all logs and comms with the subprocess has finished, lets @@ -988,21 +990,26 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): if isinstance(msg, TaskState): self._terminal_state = msg.state self._task_end_time_monotonic = time.monotonic() + self._rendered_map_index = getattr(msg, "rendered_map_index", None) elif isinstance(msg, SucceedTask): self._terminal_state = msg.state self._task_end_time_monotonic = time.monotonic() + self._rendered_map_index = getattr(msg, "rendered_map_index", None) self.client.task_instances.succeed( id=self.id, when=msg.end_date, task_outlets=msg.task_outlets, outlet_events=msg.outlet_events, + rendered_map_index=self._rendered_map_index, ) elif isinstance(msg, RetryTask): self._terminal_state = msg.state self._task_end_time_monotonic = time.monotonic() + self._rendered_map_index = getattr(msg, "rendered_map_index", None) self.client.task_instances.retry( id=self.id, end_date=msg.end_date, + rendered_map_index=self._rendered_map_index, ) elif isinstance(msg, GetConnection): conn = self.client.connections.get(msg.conn_id) @@ -1045,8 +1052,10 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): resp = xcom elif isinstance(msg, DeferTask): self._terminal_state = TaskInstanceState.DEFERRED + self._rendered_map_index = getattr(msg, "rendered_map_index", None) self.client.task_instances.defer(self.id, msg) elif isinstance(msg, RescheduleTask): + self._rendered_map_index = getattr(msg, "rendered_map_index", None) self._terminal_state = TaskInstanceState.UP_FOR_RESCHEDULE self.client.task_instances.reschedule(self.id, msg) elif isinstance(msg, SkipDownstreamTasks): diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 6e1e5f885f5e1..d96ba0ccdadaf 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -19,6 +19,7 @@ from __future__ import annotations +import contextlib import contextvars import functools import os @@ -134,6 +135,8 @@ class RuntimeTaskInstance(TaskInstance): is_mapped: bool | None = None """True if the original task was mapped.""" + rendered_map_index: str | None = None + def __rich_repr__(self): yield "id", self.id yield "task_id", self.task_id @@ -830,8 +833,17 @@ def run( msg = early_exit ti.state = state = TaskInstanceState.FAILED return state, msg, error - - result = _execute_task(context, ti, log) + jinja_env = ti.task.dag.get_template_env() + + try: + result = _execute_task(context, ti, log) + except Exception: + # If the task failed, swallow rendering error so it doesn't mask the main error. + with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError): + ti.rendered_map_index = _render_map_index(context, jinja_env=jinja_env, log=log) + raise + else: # If the task succeeded, render normally to let rendering error bubble up. + ti.rendered_map_index = _render_map_index(context, jinja_env=jinja_env, log=log) _push_xcom_if_needed(result, ti, log) @@ -851,6 +863,7 @@ def run( msg = TaskState( state=TaskInstanceState.SKIPPED, end_date=datetime.now(tz=timezone.utc), + rendered_map_index=ti.rendered_map_index ) state = TaskInstanceState.SKIPPED except AirflowRescheduleException as reschedule: @@ -868,6 +881,7 @@ def run( msg = TaskState( state=TaskInstanceState.FAILED, end_date=datetime.now(tz=timezone.utc), + rendered_map_index=ti.rendered_map_index ) state = TaskInstanceState.FAILED error = e @@ -884,6 +898,7 @@ def run( msg = TaskState( state=TaskInstanceState.FAILED, end_date=datetime.now(tz=timezone.utc), + rendered_map_index=ti.rendered_map_index, ) state = TaskInstanceState.FAILED error = e @@ -915,6 +930,7 @@ def _handle_current_task_success( end_date=datetime.now(tz=timezone.utc), task_outlets=task_outlets, outlet_events=outlet_events, + rendered_map_index=ti.rendered_map_index, ) return msg, TaskInstanceState.SUCCESS @@ -925,7 +941,7 @@ def _handle_current_task_failed( end_date = datetime.now(tz=timezone.utc) if ti._ti_context_from_server and ti._ti_context_from_server.should_retry: return RetryTask(end_date=end_date), TaskInstanceState.UP_FOR_RETRY - return TaskState(state=TaskInstanceState.FAILED, end_date=end_date), TaskInstanceState.FAILED + return TaskState(state=TaskInstanceState.FAILED, end_date=end_date, rendered_map_index=ti.rendered_map_index), TaskInstanceState.FAILED def _handle_trigger_dag_run( @@ -1106,6 +1122,16 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): return result +def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None, log: Logger) -> str | None: + """Render named map index if the DAG author defined map_index_template at the task level.""" + if jinja_env is None or (template := context.get("map_index_template")) is None: + return None + rendered_map_index = jinja_env.from_string(template).render(context) + log.info("Map index rendered as %s", rendered_map_index) + print("Map index rendered as " + rendered_map_index) + return rendered_map_index + + def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger): """Push XCom values when task has ``do_xcom_push`` set to ``True`` and the task returns a result.""" if ti.task.do_xcom_push: From 393d664a74a812d0efa524ffbe937fee3c9891b4 Mon Sep 17 00:00:00 2001 From: Purna Chander Date: Sat, 26 Apr 2025 14:54:32 +0530 Subject: [PATCH 2/5] test cases and code refactor #49224 --- task-sdk/src/airflow/sdk/api/client.py | 13 +++++++++---- .../airflow/sdk/api/datamodels/_generated.py | 1 - .../airflow/sdk/execution_time/supervisor.py | 6 ++++-- .../airflow/sdk/execution_time/task_runner.py | 17 +++++++++-------- task-sdk/tests/task_sdk/api/test_client.py | 10 ++++++++-- .../task_sdk/execution_time/test_supervisor.py | 3 ++- 6 files changed, 32 insertions(+), 18 deletions(-) diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 99262d6476ebb..5f76e8360be7d 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -151,8 +151,9 @@ def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: datetime, if state == TaskInstanceState.SUCCESS: raise ValueError("Logic error. SUCCESS state should call the `succeed` function instead") # TODO: handle the naming better. finish sounds wrong as "even" deferred is essentially finishing. - body = TITerminalStatePayload(end_date=when, state=TerminalStateNonSuccess(state), - rendered_map_index=rendered_map_index) + body = TITerminalStatePayload( + end_date=when, state=TerminalStateNonSuccess(state), rendered_map_index=rendered_map_index + ) self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json()) def retry(self, id: uuid.UUID, end_date: datetime, rendered_map_index): @@ -162,8 +163,12 @@ def retry(self, id: uuid.UUID, end_date: datetime, rendered_map_index): def succeed(self, id: uuid.UUID, when: datetime, task_outlets, outlet_events, rendered_map_index): """Tell the API server that this TI has succeeded.""" - body = TISuccessStatePayload(end_date=when, task_outlets=task_outlets, outlet_events=outlet_events, - rendered_map_index=rendered_map_index) + body = TISuccessStatePayload( + end_date=when, + task_outlets=task_outlets, + outlet_events=outlet_events, + rendered_map_index=rendered_map_index, + ) self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json()) def heartbeat(self, id: uuid.UUID, pid: int): diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index d96be1d7d8f60..62eb87c4242eb 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -234,7 +234,6 @@ class TIRescheduleStatePayload(BaseModel): state: Annotated[Literal["up_for_reschedule"] | None, Field(title="State")] = "up_for_reschedule" reschedule_date: Annotated[AwareDatetime, Field(title="Reschedule Date")] end_date: Annotated[AwareDatetime, Field(title="End Date")] - rendered_map_index: Annotated[str | None, Field(title="Rendered Map Index")] = None class TIRetryStatePayload(BaseModel): diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 29064a157a3f1..17556df66694c 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -843,8 +843,10 @@ def wait(self) -> int: # by the subprocess in the `handle_requests` method. if self.final_state not in STATES_SENT_DIRECTLY: self.client.task_instances.finish( - id=self.id, state=self.final_state, when=datetime.now(tz=timezone.utc), - rendered_map_index=self._rendered_map_index + id=self.id, + state=self.final_state, + when=datetime.now(tz=timezone.utc), + rendered_map_index=self._rendered_map_index, ) # Now at the last possible moment, when all logs and comms with the subprocess has finished, lets diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index d96ba0ccdadaf..546d6f4203b61 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -34,6 +34,7 @@ import aiologic import attrs +import jinja2 import lazy_object_proxy import structlog from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter @@ -97,7 +98,6 @@ from airflow.utils.timezone import coerce_datetime if TYPE_CHECKING: - import jinja2 from pendulum.datetime import DateTime from structlog.typing import FilteringBoundLogger as Logger @@ -863,7 +863,7 @@ def run( msg = TaskState( state=TaskInstanceState.SKIPPED, end_date=datetime.now(tz=timezone.utc), - rendered_map_index=ti.rendered_map_index + rendered_map_index=ti.rendered_map_index, ) state = TaskInstanceState.SKIPPED except AirflowRescheduleException as reschedule: @@ -881,7 +881,7 @@ def run( msg = TaskState( state=TaskInstanceState.FAILED, end_date=datetime.now(tz=timezone.utc), - rendered_map_index=ti.rendered_map_index + rendered_map_index=ti.rendered_map_index, ) state = TaskInstanceState.FAILED error = e @@ -941,7 +941,9 @@ def _handle_current_task_failed( end_date = datetime.now(tz=timezone.utc) if ti._ti_context_from_server and ti._ti_context_from_server.should_retry: return RetryTask(end_date=end_date), TaskInstanceState.UP_FOR_RETRY - return TaskState(state=TaskInstanceState.FAILED, end_date=end_date, rendered_map_index=ti.rendered_map_index), TaskInstanceState.FAILED + return TaskState( + state=TaskInstanceState.FAILED, end_date=end_date, rendered_map_index=ti.rendered_map_index + ), TaskInstanceState.FAILED def _handle_trigger_dag_run( @@ -967,11 +969,11 @@ def _handle_trigger_dag_run( "Dag Run already exists, skipping task as skip_when_already_exists is set to True.", dag_id=drte.trigger_dag_id, ) - msg = TaskState(state=TaskInstanceState.SKIPPED, end_date=datetime.now(tz=timezone.utc)) + msg = TaskState(state=TaskInstanceState.SKIPPED, end_date=datetime.now(tz=timezone.utc), rendered_map_index=ti.rendered_map_index,) state = TaskInstanceState.SKIPPED else: log.error("Dag Run already exists, marking task as failed.", dag_id=drte.trigger_dag_id) - msg = TaskState(state=TaskInstanceState.FAILED, end_date=datetime.now(tz=timezone.utc)) + msg = TaskState(state=TaskInstanceState.FAILED, end_date=datetime.now(tz=timezone.utc), rendered_map_index=ti.rendered_map_index,) state = TaskInstanceState.FAILED return msg, state @@ -1017,7 +1019,7 @@ def _handle_trigger_dag_run( log.error( "DagRun finished with failed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state ) - msg = TaskState(state=TaskInstanceState.FAILED, end_date=datetime.now(tz=timezone.utc)) + msg = TaskState(state=TaskInstanceState.FAILED, end_date=datetime.now(tz=timezone.utc), rendered_map_index=ti.rendered_map_index,) state = TaskInstanceState.FAILED return msg, state if comms_msg.state in drte.allowed_states: @@ -1128,7 +1130,6 @@ def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None, return None rendered_map_index = jinja_env.from_string(template).render(context) log.info("Map index rendered as %s", rendered_map_index) - print("Map index rendered as " + rendered_map_index) return rendered_map_index diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index e1f678bb1f77f..ce5329197493e 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -295,13 +295,16 @@ def handle_request(request: httpx.Request) -> httpx.Response: actual_body = json.loads(request.read()) assert actual_body["end_date"] == "2024-10-31T12:00:00Z" assert actual_body["state"] == state + assert actual_body["rendered_map_index"] == "test" return httpx.Response( status_code=204, ) return httpx.Response(status_code=400, json={"detail": "Bad Request"}) client = make_client(transport=httpx.MockTransport(handle_request)) - client.task_instances.finish(ti_id, state=state, when="2024-10-31T12:00:00Z") + client.task_instances.finish( + ti_id, state=state, when="2024-10-31T12:00:00Z", rendered_map_index="test" + ) def test_task_instance_heartbeat(self): # Simulate a successful response from the server that sends a heartbeat for a ti @@ -383,13 +386,16 @@ def handle_request(request: httpx.Request) -> httpx.Response: actual_body = json.loads(request.read()) assert actual_body["state"] == "up_for_retry" assert actual_body["end_date"] == "2024-10-31T12:00:00Z" + assert actual_body["rendered_map_index"] == "test" return httpx.Response( status_code=204, ) return httpx.Response(status_code=400, json={"detail": "Bad Request"}) client = make_client(transport=httpx.MockTransport(handle_request)) - client.task_instances.retry(ti_id, end_date=timezone.parse("2024-10-31T12:00:00Z")) + client.task_instances.retry( + ti_id, end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test" + ) @pytest.mark.parametrize( "rendered_fields", diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 0795e1155c985..9dd1b117057ea 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -1197,7 +1197,7 @@ def watched_subprocess(self, mocker): b"", "task_instances.retry", (), - {"id": TI_ID, "end_date": timezone.parse("2024-10-31T12:00:00Z")}, + {"id": TI_ID, "end_date": timezone.parse("2024-10-31T12:00:00Z"), "rendered_map_index": None}, "", id="up_for_retry", ), @@ -1326,6 +1326,7 @@ def watched_subprocess(self, mocker): "outlet_events": None, "task_outlets": None, "when": timezone.parse("2024-10-31T12:00:00Z"), + "rendered_map_index": None, }, "", id="succeed_task", From 184f7a2c9141bf7220f3d5619368eec8fdce4a92 Mon Sep 17 00:00:00 2001 From: Purna Chander Date: Mon, 28 Apr 2025 21:01:43 +0530 Subject: [PATCH 3/5] code refactor with Cadwyn and pr comments fix --- .../execution_api/datamodels/taskinstance.py | 8 ++-- .../execution_api/versions/__init__.py | 3 ++ .../execution_api/versions/v2025_04_28.py | 40 +++++++++++++++++++ .../api_fastapi/execution_api/test_app.py | 2 +- .../airflow/sdk/api/datamodels/_generated.py | 2 +- .../airflow/sdk/execution_time/task_runner.py | 17 +++++--- .../execution_time/test_supervisor.py | 16 ++++++-- 7 files changed, 72 insertions(+), 16 deletions(-) create mode 100644 airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_04_28.py diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index 6fd65b185e00d..c3a8be3134ced 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -75,7 +75,7 @@ class TITerminalStatePayload(StrictBaseModel): end_date: UtcDateTime """When the task completed executing""" - rendered_map_index: Annotated[str | None, Field(default_factory=str)] + rendered_map_index: str | None = None class TISuccessStatePayload(StrictBaseModel): @@ -98,7 +98,7 @@ class TISuccessStatePayload(StrictBaseModel): task_outlets: Annotated[list[AssetProfile], Field(default_factory=list)] outlet_events: Annotated[list[dict[str, Any]], Field(default_factory=list)] - rendered_map_index: Annotated[str | None, Field(default_factory=str)] + rendered_map_index: str | None = None class TITargetStatePayload(StrictBaseModel): @@ -138,7 +138,7 @@ class TIDeferredStatePayload(StrictBaseModel): Both forms will be passed along to the TaskSDK upon resume, the server will not handle either. """ - rendered_map_index: Annotated[str | None, Field(default_factory=str)] + rendered_map_index: str | None = None class TIRescheduleStatePayload(StrictBaseModel): @@ -174,7 +174,7 @@ class TIRetryStatePayload(StrictBaseModel): ), ] end_date: UtcDateTime - rendered_map_index: Annotated[str | None, Field(default_factory=str)] + rendered_map_index: str | None = None class TISkippedDownstreamTasksStatePayload(StrictBaseModel): diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index 0d3a225305b67..54329054c1213 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -19,7 +19,10 @@ from cadwyn import HeadVersion, Version, VersionBundle +from airflow.api_fastapi.execution_api.versions.v2025_04_28 import AddRenderedMapIndexField + bundle = VersionBundle( HeadVersion(), + Version("2025-04-28", AddRenderedMapIndexField), Version("2025-04-11"), ) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_04_28.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_04_28.py new file mode 100644 index 0000000000000..e0916b4c93d67 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_04_28.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from cadwyn import VersionChange, schema + +from airflow.api_fastapi.execution_api.datamodels.taskinstance import ( + TIDeferredStatePayload, + TIRetryStatePayload, + TISuccessStatePayload, + TITerminalStatePayload, +) + + +class AddRenderedMapIndexField(VersionChange): + """Add the `rendered_map_index` field to payload models.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = ( + schema(TITerminalStatePayload).field("rendered_map_index").didnt_exist, + schema(TISuccessStatePayload).field("rendered_map_index").didnt_exist, + schema(TIDeferredStatePayload).field("rendered_map_index").didnt_exist, + schema(TIRetryStatePayload).field("rendered_map_index").didnt_exist, + ) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py b/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py index adaa93bc202af..32c53ae0db9cb 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py @@ -26,7 +26,7 @@ def test_custom_openapi_includes_extra_schemas(client): """Test to ensure that extra schemas are correctly included in the OpenAPI schema.""" - response = client.get("/execution/openapi.json?version=2025-04-11") + response = client.get("/execution/openapi.json?version=2025-04-28") assert response.status_code == 200 openapi_schema = response.json() diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index 62eb87c4242eb..a477c3ffc410f 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -27,7 +27,7 @@ from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue -API_VERSION: Final[str] = "2025-04-11" +API_VERSION: Final[str] = "2025-04-28" class AssetAliasReferenceAssetEventDagRun(BaseModel): diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 546d6f4203b61..065de8f1b404a 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -34,7 +34,6 @@ import aiologic import attrs -import jinja2 import lazy_object_proxy import structlog from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter @@ -98,6 +97,7 @@ from airflow.utils.timezone import coerce_datetime if TYPE_CHECKING: + import jinja2 from pendulum.datetime import DateTime from structlog.typing import FilteringBoundLogger as Logger @@ -833,17 +833,18 @@ def run( msg = early_exit ti.state = state = TaskInstanceState.FAILED return state, msg, error - jinja_env = ti.task.dag.get_template_env() try: result = _execute_task(context, ti, log) except Exception: + import jinja2 + # If the task failed, swallow rendering error so it doesn't mask the main error. with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError): - ti.rendered_map_index = _render_map_index(context, jinja_env=jinja_env, log=log) + ti.rendered_map_index = _render_map_index(context, ti=ti, log=log) raise else: # If the task succeeded, render normally to let rendering error bubble up. - ti.rendered_map_index = _render_map_index(context, jinja_env=jinja_env, log=log) + ti.rendered_map_index = _render_map_index(context, ti=ti, log=log) _push_xcom_if_needed(result, ti, log) @@ -1124,10 +1125,14 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): return result -def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None, log: Logger) -> str | None: +def _render_map_index(context: Context, ti: RuntimeTaskInstance, log: Logger) -> str | None: """Render named map index if the DAG author defined map_index_template at the task level.""" - if jinja_env is None or (template := context.get("map_index_template")) is None: + if template := context.get("map_index_template") is None: + return None + if not isinstance(template, str): + log.error("Expected `template` to be a string, but got %s", type(template).__name__) return None + jinja_env = ti.task.dag.get_template_env() rendered_map_index = jinja_env.from_string(template).render(context) log.info("Map index rendered as %s", rendered_map_index) return rendered_map_index diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 9dd1b117057ea..5690f9b418ec6 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -1193,11 +1193,17 @@ def watched_subprocess(self, mocker): id="patch_task_instance_to_skipped", ), pytest.param( - RetryTask(end_date=timezone.parse("2024-10-31T12:00:00Z")), + RetryTask( + end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test retry task" + ), b"", "task_instances.retry", (), - {"id": TI_ID, "end_date": timezone.parse("2024-10-31T12:00:00Z"), "rendered_map_index": None}, + { + "id": TI_ID, + "end_date": timezone.parse("2024-10-31T12:00:00Z"), + "rendered_map_index": "test retry task", + }, "", id="up_for_retry", ), @@ -1317,7 +1323,9 @@ def watched_subprocess(self, mocker): id="get_asset_events_by_asset_alias", ), pytest.param( - SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")), + SucceedTask( + end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test success task" + ), b"", "task_instances.succeed", (), @@ -1326,7 +1334,7 @@ def watched_subprocess(self, mocker): "outlet_events": None, "task_outlets": None, "when": timezone.parse("2024-10-31T12:00:00Z"), - "rendered_map_index": None, + "rendered_map_index": "test success task", }, "", id="succeed_task", From aa4ee262b4445e5bbfc5fddc00be3a119f3b179f Mon Sep 17 00:00:00 2001 From: Purna Chander Date: Fri, 2 May 2025 11:58:06 +0530 Subject: [PATCH 4/5] test cases for render map index --- .../example_params_ui_tutorial.py | 0 .../src/airflow/jobs/triggerer_job_runner.py | 0 .../src/airflow/provider_info.schema.json | 0 .../ui/openapi-gen/queries/prefetch.ts | 0 airflow-core/src/airflow/ui/package.json | 0 .../components/DagActions/RunBackfillForm.tsx | 0 .../versions/head/test_task_instances.py | 34 +++++++++++++++++++ .../test_docker_compose_quick_start.py | 0 .../airflow/sdk/execution_time/task_runner.py | 23 +++++++++---- .../execution_time/test_task_runner.py | 31 +++++++++++++++++ 10 files changed, 81 insertions(+), 7 deletions(-) mode change 100755 => 100644 airflow-core/src/airflow/example_dags/example_params_ui_tutorial.py mode change 100755 => 100644 airflow-core/src/airflow/jobs/triggerer_job_runner.py mode change 100755 => 100644 airflow-core/src/airflow/provider_info.schema.json mode change 100755 => 100644 airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts mode change 100755 => 100644 airflow-core/src/airflow/ui/package.json mode change 100755 => 100644 airflow-core/src/airflow/ui/src/components/DagActions/RunBackfillForm.tsx mode change 100755 => 100644 docker-tests/tests/docker_tests/test_docker_compose_quick_start.py diff --git a/airflow-core/src/airflow/example_dags/example_params_ui_tutorial.py b/airflow-core/src/airflow/example_dags/example_params_ui_tutorial.py old mode 100755 new mode 100644 diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py old mode 100755 new mode 100644 diff --git a/airflow-core/src/airflow/provider_info.schema.json b/airflow-core/src/airflow/provider_info.schema.json old mode 100755 new mode 100644 diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts old mode 100755 new mode 100644 diff --git a/airflow-core/src/airflow/ui/package.json b/airflow-core/src/airflow/ui/package.json old mode 100755 new mode 100644 diff --git a/airflow-core/src/airflow/ui/src/components/DagActions/RunBackfillForm.tsx b/airflow-core/src/airflow/ui/src/components/DagActions/RunBackfillForm.tsx old mode 100755 new mode 100644 diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 8af11c04df960..41bba2c2bd522 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -44,6 +44,7 @@ DEFAULT_START_DATE = timezone.parse("2024-10-31T11:00:00Z") DEFAULT_END_DATE = timezone.parse("2024-10-31T12:00:00Z") +DEFAULT_RENDERED_MAP_INDEX = "test rendered map index" def _create_asset_aliases(session, num: int = 2) -> None: @@ -465,6 +466,39 @@ def test_ti_update_state_to_terminal( assert ti.state == expected_state assert ti.end_date == end_date + @pytest.mark.parametrize( + ("state", "end_date", "expected_state", "rendered_map_index"), + [ + (State.SUCCESS, DEFAULT_END_DATE, State.SUCCESS, DEFAULT_RENDERED_MAP_INDEX), + (State.FAILED, DEFAULT_END_DATE, State.FAILED, DEFAULT_RENDERED_MAP_INDEX), + (State.SKIPPED, DEFAULT_END_DATE, State.SKIPPED, DEFAULT_RENDERED_MAP_INDEX), + ], + ) + def test_ti_update_state_to_terminal_with_rendered_map_index( + self, client, session, create_task_instance, state, end_date, expected_state, rendered_map_index + ): + ti = create_task_instance( + task_id="test_ti_update_state_to_terminal_with_rendered_map_index", + start_date=DEFAULT_START_DATE, + state=State.RUNNING, + ) + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/state", + json={"state": state, "end_date": end_date.isoformat(), "rendered_map_index": rendered_map_index}, + ) + + assert response.status_code == 204 + assert response.text == "" + + session.expire_all() + + ti = session.get(TaskInstance, ti.id) + assert ti.state == expected_state + assert ti.end_date == end_date + assert ti.rendered_map_index == rendered_map_index + @pytest.mark.parametrize( "task_outlets", [ diff --git a/docker-tests/tests/docker_tests/test_docker_compose_quick_start.py b/docker-tests/tests/docker_tests/test_docker_compose_quick_start.py old mode 100755 new mode 100644 diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 065de8f1b404a..9092ee86f0b15 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -970,11 +970,19 @@ def _handle_trigger_dag_run( "Dag Run already exists, skipping task as skip_when_already_exists is set to True.", dag_id=drte.trigger_dag_id, ) - msg = TaskState(state=TaskInstanceState.SKIPPED, end_date=datetime.now(tz=timezone.utc), rendered_map_index=ti.rendered_map_index,) + msg = TaskState( + state=TaskInstanceState.SKIPPED, + end_date=datetime.now(tz=timezone.utc), + rendered_map_index=ti.rendered_map_index, + ) state = TaskInstanceState.SKIPPED else: log.error("Dag Run already exists, marking task as failed.", dag_id=drte.trigger_dag_id) - msg = TaskState(state=TaskInstanceState.FAILED, end_date=datetime.now(tz=timezone.utc), rendered_map_index=ti.rendered_map_index,) + msg = TaskState( + state=TaskInstanceState.FAILED, + end_date=datetime.now(tz=timezone.utc), + rendered_map_index=ti.rendered_map_index, + ) state = TaskInstanceState.FAILED return msg, state @@ -1020,7 +1028,11 @@ def _handle_trigger_dag_run( log.error( "DagRun finished with failed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state ) - msg = TaskState(state=TaskInstanceState.FAILED, end_date=datetime.now(tz=timezone.utc), rendered_map_index=ti.rendered_map_index,) + msg = TaskState( + state=TaskInstanceState.FAILED, + end_date=datetime.now(tz=timezone.utc), + rendered_map_index=ti.rendered_map_index, + ) state = TaskInstanceState.FAILED return msg, state if comms_msg.state in drte.allowed_states: @@ -1127,10 +1139,7 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): def _render_map_index(context: Context, ti: RuntimeTaskInstance, log: Logger) -> str | None: """Render named map index if the DAG author defined map_index_template at the task level.""" - if template := context.get("map_index_template") is None: - return None - if not isinstance(template, str): - log.error("Expected `template` to be a string, but got %s", type(template).__name__) + if (template := context.get("map_index_template")) is None: return None jinja_env = ti.task.dag.get_template_env() rendered_map_index = jinja_env.from_string(template).render(context) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index f9ad46f357e7c..8df65b88ebf48 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -970,6 +970,37 @@ def test_function(): assert os.environ["AIRFLOW_CTX_TASK_ID"] == "test_env_task" +def test_execute_success_task_with_rendered_map_index(create_runtime_ti, mock_supervisor_comms): + """Test that the map index is rendered in the task context.""" + + def test_function(): + return "test function" + + task = PythonOperator( + task_id="test_task", + python_callable=test_function, + map_index_template="Hello! {{ run_id }}", + ) + + ti = create_runtime_ti(task=task, dag_id="dag_with_map_index_template") + + run(ti, ti.get_template_context(), log=mock.MagicMock()) + + assert ti.rendered_map_index == "Hello! test_run" + + +def test_execute_failed_task_with_rendered_map_index(create_runtime_ti, mock_supervisor_comms): + """Test that the map index is rendered in the task context.""" + + task = BaseOperator(task_id="test_task", map_index_template="Hello! {{ run_id }}") + + ti = create_runtime_ti(task=task, dag_id="dag_with_map_index_template") + + run(ti, ti.get_template_context(), log=mock.MagicMock()) + + assert ti.rendered_map_index == "Hello! test_run" + + class TestRuntimeTaskInstance: def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context): """Test get_template_context without ti_context_from_server.""" From 54be35a43a19c2b7e0ebb2a43d72ad85f9a2c051 Mon Sep 17 00:00:00 2001 From: Purna Chander Date: Mon, 5 May 2025 11:45:50 +0530 Subject: [PATCH 5/5] removed getattr to access rendered_map_index directly in supervisor --- task-sdk/src/airflow/sdk/execution_time/supervisor.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 17556df66694c..b5cf977488b71 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -992,11 +992,11 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): if isinstance(msg, TaskState): self._terminal_state = msg.state self._task_end_time_monotonic = time.monotonic() - self._rendered_map_index = getattr(msg, "rendered_map_index", None) + self._rendered_map_index = msg.rendered_map_index elif isinstance(msg, SucceedTask): self._terminal_state = msg.state self._task_end_time_monotonic = time.monotonic() - self._rendered_map_index = getattr(msg, "rendered_map_index", None) + self._rendered_map_index = msg.rendered_map_index self.client.task_instances.succeed( id=self.id, when=msg.end_date, @@ -1007,7 +1007,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): elif isinstance(msg, RetryTask): self._terminal_state = msg.state self._task_end_time_monotonic = time.monotonic() - self._rendered_map_index = getattr(msg, "rendered_map_index", None) + self._rendered_map_index = msg.rendered_map_index self.client.task_instances.retry( id=self.id, end_date=msg.end_date, @@ -1054,10 +1054,9 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): resp = xcom elif isinstance(msg, DeferTask): self._terminal_state = TaskInstanceState.DEFERRED - self._rendered_map_index = getattr(msg, "rendered_map_index", None) + self._rendered_map_index = msg.rendered_map_index self.client.task_instances.defer(self.id, msg) elif isinstance(msg, RescheduleTask): - self._rendered_map_index = getattr(msg, "rendered_map_index", None) self._terminal_state = TaskInstanceState.UP_FOR_RESCHEDULE self.client.task_instances.reschedule(self.id, msg) elif isinstance(msg, SkipDownstreamTasks):