Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions scripts/ci/prek/check_template_context_variable_in_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,25 @@ def extract_keys_from_dict(node: ast.Dict) -> typing.Iterator[str]:
yield key.value

# Extract keys from the main `context` dictionary assignment
context_assignment = next(
context_assignment: ast.AnnAssign = next(
stmt
for stmt in fn_get_template_context.body
if isinstance(stmt, ast.AnnAssign)
and isinstance(stmt.target, ast.Name)
and stmt.target.id == "context"
and isinstance(stmt.target, ast.Attribute)
and isinstance(stmt.target.value, ast.Name)
and stmt.target.value.id == "self"
and stmt.target.attr == "_context"
)

if not isinstance(context_assignment.value, ast.Dict):
if not isinstance(context_assignment.value, ast.BoolOp):
raise TypeError("Expected a BoolOp like 'self._context or {...}'.")

context_assignment_op = context_assignment.value
_, context_assignment_value = context_assignment_op.values

if not isinstance(context_assignment_value, ast.Dict):
raise ValueError("'context' is not assigned a dictionary literal")
yield from extract_keys_from_dict(context_assignment.value)
yield from extract_keys_from_dict(context_assignment_value)

# Handle keys added conditionally in `if from_server`
for stmt in fn_get_template_context.body:
Expand Down
13 changes: 9 additions & 4 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ class RuntimeTaskInstance(TaskInstance):

task: BaseOperator
bundle_instance: BaseDagBundle
_context: Context | None = None
"""The Task Instance context."""

_ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None
"""The Task Instance context from the API server, if any."""

Expand Down Expand Up @@ -173,7 +176,9 @@ def get_template_context(self) -> Context:

validated_params = process_params(self.task.dag, self.task, dag_run_conf, suppress_exception=False)

context: Context = {
# Cache the context object, which ensures that all calls to get_template_context
# are operating on the same context object.
self._context: Context = self._context or {
# From the Task Execution interface
"dag": self.task.dag,
"inlets": self.task.inlets,
Expand Down Expand Up @@ -215,7 +220,7 @@ def get_template_context(self) -> Context:
lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date)
),
}
context.update(context_from_server)
self._context.update(context_from_server)

if logical_date := coerce_datetime(dag_run.logical_date):
if TYPE_CHECKING:
Expand All @@ -226,7 +231,7 @@ def get_template_context(self) -> Context:
ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
# logical_date and data_interval either coexist or be None together
context.update(
self._context.update(
{
# keys that depend on logical_date
"logical_date": logical_date,
Expand All @@ -253,7 +258,7 @@ def get_template_context(self) -> Context:
# existence. Should this be a private attribute on RuntimeTI instead perhaps?
setattr(self, "_upstream_map_indexes", from_server.upstream_map_indexes)

return context
return self._context

def render_templates(
self, context: Context | None = None, jinja_env: jinja2.Environment | None = None
Expand Down
140 changes: 139 additions & 1 deletion task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
)
from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.definitions._internal.types import NOTSET, SET_DURING_EXECUTION, ArgNotSet
from airflow.sdk.definitions.asset import Asset, AssetAlias, Dataset, Model
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, Dataset, Model
from airflow.sdk.definitions.param import DagParam
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import (
Expand Down Expand Up @@ -2502,6 +2502,32 @@ def on_task_instance_failed(self, previous_state, task_instance, error):
def before_stopping(self, component):
self.component = component

class CustomOutletEventsListener:
    """Listener that records outlet event accessors from the task instance's template context.

    Each hook calls ``task_instance.get_template_context()`` and stores, for every
    entry in ``context["outlets"]``, the matching ``context["outlet_events"]`` item.
    The failure hook additionally keeps the error it was handed.
    """

    def __init__(self):
        self.error = None
        self.outlet_events = []

    def _add_outlet_events(self, context):
        """Append the ``outlet_events`` entry of every outlet listed in ``context``."""
        # ``context["outlets"]`` is read first; ``context["outlet_events"]`` is only
        # looked up once per outlet, so it is never touched when there are no outlets.
        self.outlet_events.extend(context["outlet_events"][outlet] for outlet in context["outlets"])

    @hookimpl
    def on_task_instance_running(self, previous_state, task_instance):
        self._add_outlet_events(task_instance.get_template_context())

    @hookimpl
    def on_task_instance_success(self, previous_state, task_instance):
        self._add_outlet_events(task_instance.get_template_context())

    @hookimpl
    def on_task_instance_failed(self, previous_state, task_instance, error):
        self._add_outlet_events(task_instance.get_template_context())
        self.error = error

@pytest.fixture(autouse=True)
def clean_listener_manager(self):
lm = get_listener_manager()
Expand Down Expand Up @@ -2622,6 +2648,118 @@ def execute(self, context):
assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.FAILED]
assert listener.error == error

def test_listener_access_outlet_event_on_running_and_success(self, mocked_parse, mock_supervisor_comms):
    """A listener calling get_template_context() sees the outlet event extra set by the task.

    The recorded accessor is checked twice: once after ``run()`` and once after ``finalize()``.
    """
    outlet_listener = self.CustomOutletEventsListener()
    get_listener_manager().add_listener(outlet_listener)

    asset = Asset("test-asset")
    expected_key = AssetUniqueKey(name="test-asset", uri="test-asset")
    expected_extra = {"name1": "value1", "nested_obj": {"name2": "value2"}}

    def assert_latest_outlet_event():
        # Pop the most recently recorded accessor and compare it with what the task wrote.
        accessor = outlet_listener.outlet_events.pop()
        assert accessor.key == expected_key
        assert accessor.extra == expected_extra

    class Producer(BaseOperator):
        def execute(self, context):
            context["outlet_events"][asset].extra = expected_extra

    producer = Producer(
        task_id="test_listener_access_outlet_event_on_running_and_success",
        outlets=[asset],
    )
    inline_dag = get_inline_dag(dag_id="test_dag", task=producer)
    base_ti = TaskInstance(
        id=uuid7(),
        task_id=producer.task_id,
        dag_id=inline_dag.dag_id,
        run_id="test_run",
        try_number=1,
        dag_version_id=uuid7(),
    )

    ti_fields = base_ti.model_dump(exclude_unset=True)
    runtime_ti = RuntimeTaskInstance.model_construct(
        **ti_fields,
        task=producer,
        start_date=timezone.utcnow(),
    )

    log = mock.MagicMock()
    context = runtime_ti.get_template_context()

    validator_path = "airflow.sdk.execution_time.task_runner._validate_task_inlets_and_outlets"
    with mock.patch(validator_path) as validate_mock:
        state, _msg, _error = run(runtime_ti, context, log)

    validate_mock.assert_called_once()
    assert_latest_outlet_event()

    finalize(runtime_ti, state, context, log)
    assert_latest_outlet_event()

@pytest.mark.parametrize(
    "exception",
    [
        ValueError("oops"),
        SystemExit("oops"),
        AirflowException("oops"),
    ],
    ids=["ValueError", "SystemExit", "AirflowException"],
)
def test_listener_access_outlet_event_on_failed(self, mocked_parse, mock_supervisor_comms, exception):
    """A listener calling get_template_context() sees the outlet event extra when the task raises.

    The recorded accessor is checked after ``run()`` and again after ``finalize()``;
    the listener must also have stored the error returned by ``run()``.
    """
    outlet_listener = self.CustomOutletEventsListener()
    get_listener_manager().add_listener(outlet_listener)

    asset = Asset("test-asset")
    expected_key = AssetUniqueKey(name="test-asset", uri="test-asset")
    expected_extra = {"name1": "value1", "nested_obj": {"name2": "value2"}}

    def assert_latest_outlet_event():
        # Pop the most recently recorded accessor and compare it with what the task wrote.
        accessor = outlet_listener.outlet_events.pop()
        assert accessor.key == expected_key
        assert accessor.extra == expected_extra

    class Producer(BaseOperator):
        def execute(self, context):
            # Set the extra before failing so the listener has something to observe.
            context["outlet_events"][asset].extra = expected_extra
            raise exception

    producer = Producer(task_id="test_listener_access_outlet_event_on_failed", outlets=[asset])
    inline_dag = get_inline_dag(dag_id="test_dag", task=producer)
    base_ti = TaskInstance(
        id=uuid7(),
        task_id=producer.task_id,
        dag_id=inline_dag.dag_id,
        run_id="test_run",
        try_number=1,
        dag_version_id=uuid7(),
    )

    ti_fields = base_ti.model_dump(exclude_unset=True)
    runtime_ti = RuntimeTaskInstance.model_construct(
        **ti_fields,
        task=producer,
        start_date=timezone.utcnow(),
    )

    log = mock.MagicMock()
    context = runtime_ti.get_template_context()

    validator_path = "airflow.sdk.execution_time.task_runner._validate_task_inlets_and_outlets"
    with mock.patch(validator_path) as validate_mock:
        state, _msg, error = run(runtime_ti, context, log)

    validate_mock.assert_called_once()
    assert_latest_outlet_event()

    finalize(runtime_ti, state, context, log, error)
    assert_latest_outlet_event()

    assert outlet_listener.error == error


@pytest.mark.usefixtures("mock_supervisor_comms")
class TestTaskRunnerCallsCallbacks:
Expand Down