diff --git a/scripts/ci/prek/check_template_context_variable_in_sync.py b/scripts/ci/prek/check_template_context_variable_in_sync.py index dec8708e2825e..01f4c461c6802 100755 --- a/scripts/ci/prek/check_template_context_variable_in_sync.py +++ b/scripts/ci/prek/check_template_context_variable_in_sync.py @@ -90,11 +90,11 @@ def extract_keys_from_dict(node: ast.Dict) -> typing.Iterator[str]: and isinstance(stmt.target, ast.Attribute) and isinstance(stmt.target.value, ast.Name) and stmt.target.value.id == "self" - and stmt.target.attr == "_context" + and stmt.target.attr == "_cached_template_context" ) if not isinstance(context_assignment.value, ast.BoolOp): - raise TypeError("Expected a BoolOp like 'self._context or {...}'.") + raise TypeError("Expected a BoolOp like 'self._cached_template_context or {...}'.") context_assignment_op = context_assignment.value _, context_assignment_value = context_assignment_op.values diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index e4fd3ace38c75..dbcd6329bce5d 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -131,8 +131,8 @@ class RuntimeTaskInstance(TaskInstance): task: BaseOperator bundle_instance: BaseDagBundle - _context: Context | None = None - """The Task Instance context.""" + _cached_template_context: Context | None = None + """The Task Instance context. This is used to cache get_template_context.""" _ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None """The Task Instance context from the API server, if any.""" @@ -178,7 +178,7 @@ def get_template_context(self) -> Context: # Cache the context object, which ensures that all calls to get_template_context # are operating on the same context object. - self._context: Context = self._context or { + self._cached_template_context: Context = self._cached_template_context or { # From the Task Execution interface "dag": self.task.dag, "inlets": self.task.inlets, @@ -218,7 +218,7 @@ def get_template_context(self) -> Context: lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date) ), } - self._context.update(context_from_server) + self._cached_template_context.update(context_from_server) if logical_date := coerce_datetime(dag_run.logical_date): if TYPE_CHECKING: @@ -229,7 +229,7 @@ def get_template_context(self) -> Context: ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S") ts_nodash_with_tz = ts.replace("-", "").replace(":", "") # logical_date and data_interval either coexist or be None together - self._context.update( + self._cached_template_context.update( { # keys that depend on logical_date "logical_date": logical_date, @@ -256,7 +256,7 @@ def get_template_context(self) -> Context: # existence. Should this be a private attribute on RuntimeTI instead perhaps? setattr(self, "_upstream_map_indexes", from_server.upstream_map_indexes) - return self._context + return self._cached_template_context def render_templates( self, context: Context | None = None, jinja_env: jinja2.Environment | None = None