diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 48f4f17601a90..3c512be9cd594 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -812,6 +812,9 @@ def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[d def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSupervisor | None: ti.hostname = get_hostname() ti.task = ti.task.prepare_for_execution() + # Since context is now cached, and calling `ti.get_template_context` will return the same dict, we want to + # update the value of the task that is sent from there + context["task"] = ti.task jinja_env = ti.task.dag.get_template_env() ti.render_templates(context=context, jinja_env=jinja_env) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 05dc200f7f91a..0a7261ae586af 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -603,6 +603,7 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm spy_agency.assert_spy_called(task.prepare_for_execution) assert ti.task._lock_for_execution assert ti.task is not task, "ti.task should be a copy of the original task" + assert ti.task is ti.get_template_context()["task"], "task in context should be updated too" assert ti.state == TaskInstanceState.SUCCESS mock_supervisor_comms.send.assert_any_call(