Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions scripts/ci/prek/check_template_context_variable_in_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,11 @@ def extract_keys_from_dict(node: ast.Dict) -> typing.Iterator[str]:
and isinstance(stmt.target, ast.Attribute)
and isinstance(stmt.target.value, ast.Name)
and stmt.target.value.id == "self"
and stmt.target.attr == "_context"
and stmt.target.attr == "_cached_template_context"
)

if not isinstance(context_assignment.value, ast.BoolOp):
raise TypeError("Expected a BoolOp like 'self._context or {...}'.")
raise TypeError("Expected a BoolOp like 'self._cached_template_context or {...}'.")

context_assignment_op = context_assignment.value
_, context_assignment_value = context_assignment_op.values
Expand Down
12 changes: 6 additions & 6 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ class RuntimeTaskInstance(TaskInstance):

task: BaseOperator
bundle_instance: BaseDagBundle
_context: Context | None = None
"""The Task Instance context."""
_cached_template_context: Context | None = None
"""The Task Instance context. This is used to cache get_template_context."""

_ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None
"""The Task Instance context from the API server, if any."""
Expand Down Expand Up @@ -178,7 +178,7 @@ def get_template_context(self) -> Context:

# Cache the context object, which ensures that all calls to get_template_context
# are operating on the same context object.
self._context: Context = self._context or {
self._cached_template_context: Context = self._cached_template_context or {
# From the Task Execution interface
"dag": self.task.dag,
"inlets": self.task.inlets,
Expand Down Expand Up @@ -218,7 +218,7 @@ def get_template_context(self) -> Context:
lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date)
),
}
self._context.update(context_from_server)
self._cached_template_context.update(context_from_server)

if logical_date := coerce_datetime(dag_run.logical_date):
if TYPE_CHECKING:
Expand All @@ -229,7 +229,7 @@ def get_template_context(self) -> Context:
ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
# logical_date and data_interval either coexist or be None together
self._context.update(
self._cached_template_context.update(
{
# keys that depend on logical_date
"logical_date": logical_date,
Expand All @@ -256,7 +256,7 @@ def get_template_context(self) -> Context:
# existence. Should this be a private attribute on RuntimeTI instead perhaps?
setattr(self, "_upstream_map_indexes", from_server.upstream_map_indexes)

return self._context
return self._cached_template_context

def render_templates(
self, context: Context | None = None, jinja_env: jinja2.Environment | None = None
Expand Down