diff --git a/airflow-core/src/airflow/models/serialized_dag.py b/airflow-core/src/airflow/models/serialized_dag.py index b6bbf67fae3c2..f0e9bd9c6dabd 100644 --- a/airflow-core/src/airflow/models/serialized_dag.py +++ b/airflow-core/src/airflow/models/serialized_dag.py @@ -460,6 +460,15 @@ def write_dag( @classmethod def latest_item_select_object(cls, dag_id): + from airflow.settings import engine + + if engine.dialect.name == "mysql": + # Prevent "Out of sort memory" caused by large values in cls.data column for MySQL. + # Details in https://github.com/apache/airflow/pull/55589 + latest_item_id = ( + select(cls.id).where(cls.dag_id == dag_id).order_by(cls.created_at.desc()).limit(1) + ) + return select(cls).where(cls.id == latest_item_id) return select(cls).where(cls.dag_id == dag_id).order_by(cls.created_at.desc()).limit(1) @classmethod diff --git a/scripts/ci/prek/check_template_context_variable_in_sync.py b/scripts/ci/prek/check_template_context_variable_in_sync.py index 0b74e4beedbd8..1c55fbd19208e 100755 --- a/scripts/ci/prek/check_template_context_variable_in_sync.py +++ b/scripts/ci/prek/check_template_context_variable_in_sync.py @@ -83,17 +83,25 @@ def extract_keys_from_dict(node: ast.Dict) -> typing.Iterator[str]: yield key.value # Extract keys from the main `context` dictionary assignment - context_assignment = next( + context_assignment: ast.AnnAssign = next( stmt for stmt in fn_get_template_context.body if isinstance(stmt, ast.AnnAssign) - and isinstance(stmt.target, ast.Name) - and stmt.target.id == "context" + and isinstance(stmt.target, ast.Attribute) + and isinstance(stmt.target.value, ast.Name) + and stmt.target.value.id == "self" + and stmt.target.attr == "_context" ) - if not isinstance(context_assignment.value, ast.Dict): + if not isinstance(context_assignment.value, ast.BoolOp): + raise TypeError("Expected a BoolOp like 'self._context or {...}'.") + + context_assignment_op = context_assignment.value + _, context_assignment_value = context_assignment_op.values + + if not isinstance(context_assignment_value, ast.Dict): raise ValueError("'context' is not assigned a dictionary literal") - yield from extract_keys_from_dict(context_assignment.value) + yield from extract_keys_from_dict(context_assignment_value) # Handle keys added conditionally in `if from_server` for stmt in fn_get_template_context.body: diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 409982d1a6bf9..7138598027126 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -131,6 +131,9 @@ class RuntimeTaskInstance(TaskInstance): task: BaseOperator bundle_instance: BaseDagBundle + _context: Context | None = None + """The Task Instance context.""" + _ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None """The Task Instance context from the API server, if any.""" @@ -173,7 +176,9 @@ def get_template_context(self) -> Context: validated_params = process_params(self.task.dag, self.task, dag_run_conf, suppress_exception=False) - context: Context = { + # Cache the context object, which ensures that all calls to get_template_context + # are operating on the same context object. + self._context: Context = self._context or { # From the Task Execution interface "dag": self.task.dag, "inlets": self.task.inlets, @@ -213,7 +218,7 @@ def get_template_context(self) -> Context: lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date) ), } - context.update(context_from_server) + self._context.update(context_from_server) if logical_date := coerce_datetime(dag_run.logical_date): if TYPE_CHECKING: @@ -224,7 +229,7 @@ def get_template_context(self) -> Context: ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S") ts_nodash_with_tz = ts.replace("-", "").replace(":", "") # logical_date and data_interval either coexist or be None together - context.update( + self._context.update( { # keys that depend on logical_date "logical_date": logical_date, @@ -251,7 +256,7 @@ def get_template_context(self) -> Context: # existence. Should this be a private attribute on RuntimeTI instead perhaps? setattr(self, "_upstream_map_indexes", from_server.upstream_map_indexes) - return context + return self._context def render_templates( self, context: Context | None = None, jinja_env: jinja2.Environment | None = None diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 930d80590ea7b..a2a19266ae5aa 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -65,7 +65,7 @@ ) from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.definitions._internal.types import NOTSET, SET_DURING_EXECUTION, ArgNotSet -from airflow.sdk.definitions.asset import Asset, AssetAlias, Dataset, Model +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, Dataset, Model from airflow.sdk.definitions.param import DagParam from airflow.sdk.exceptions import ErrorType from airflow.sdk.execution_time.comms import ( @@ -2482,6 +2482,32 @@ def on_task_instance_failed(self, previous_state, task_instance, error): def before_stopping(self, component): self.component = component + class CustomOutletEventsListener: + def __init__(self): + self.outlet_events = [] + self.error = None + + def _add_outlet_events(self, context): + outlets = context["outlets"] + for outlet in outlets: + self.outlet_events.append(context["outlet_events"][outlet]) + + @hookimpl + def on_task_instance_running(self, previous_state, task_instance): + context = task_instance.get_template_context() + self._add_outlet_events(context) + + @hookimpl + def on_task_instance_success(self, previous_state, task_instance): + context = task_instance.get_template_context() + self._add_outlet_events(context) + + @hookimpl + def on_task_instance_failed(self, previous_state, task_instance, error): + context = task_instance.get_template_context() + self._add_outlet_events(context) + self.error = error + @pytest.fixture(autouse=True) def clean_listener_manager(self): lm = get_listener_manager() @@ -2601,6 +2627,118 @@ def execute(self, context): assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.FAILED] assert listener.error == error + def test_listener_access_outlet_event_on_running_and_success(self, mocked_parse, mock_supervisor_comms): + """Test listener can access outlet events through invoking get_template_context() while task running and success""" + listener = self.CustomOutletEventsListener() + get_listener_manager().add_listener(listener) + + test_asset = Asset("test-asset") + test_key = AssetUniqueKey(name="test-asset", uri="test-asset") + test_extra = {"name1": "value1", "nested_obj": {"name2": "value2"}} + + class Producer(BaseOperator): + def execute(self, context): + outlet_events = context["outlet_events"] + outlet_events[test_asset].extra = test_extra + + task = Producer( + task_id="test_listener_access_outlet_event_on_running_and_success", outlets=[test_asset] + ) + dag = get_inline_dag(dag_id="test_dag", task=task) + ti = TaskInstance( + id=uuid7(), + task_id=task.task_id, + dag_id=dag.dag_id, + run_id="test_run", + try_number=1, + dag_version_id=uuid7(), + ) + + runtime_ti = RuntimeTaskInstance.model_construct( + **ti.model_dump(exclude_unset=True), task=task, start_date=timezone.utcnow() + ) + + log = mock.MagicMock() + context = runtime_ti.get_template_context() + + with mock.patch( + "airflow.sdk.execution_time.task_runner._validate_task_inlets_and_outlets" + ) as validate_mock: + state, _, _ = run(runtime_ti, context, log) + + validate_mock.assert_called_once() + + outlet_event_accessor = listener.outlet_events.pop() + assert outlet_event_accessor.key == test_key + assert outlet_event_accessor.extra == test_extra + + finalize(runtime_ti, state, context, log) + + outlet_event_accessor = listener.outlet_events.pop() + assert outlet_event_accessor.key == test_key + assert outlet_event_accessor.extra == test_extra + + @pytest.mark.parametrize( + "exception", + [ + ValueError("oops"), + SystemExit("oops"), + AirflowException("oops"), + ], + ids=["ValueError", "SystemExit", "AirflowException"], + ) + def test_listener_access_outlet_event_on_failed(self, mocked_parse, mock_supervisor_comms, exception): + """Test listener can access outlet events through invoking get_template_context() while task failed""" + listener = self.CustomOutletEventsListener() + get_listener_manager().add_listener(listener) + + test_asset = Asset("test-asset") + test_key = AssetUniqueKey(name="test-asset", uri="test-asset") + test_extra = {"name1": "value1", "nested_obj": {"name2": "value2"}} + + class Producer(BaseOperator): + def execute(self, context): + outlet_events = context["outlet_events"] + outlet_events[test_asset].extra = test_extra + raise exception + + task = Producer(task_id="test_listener_access_outlet_event_on_failed", outlets=[test_asset]) + dag = get_inline_dag(dag_id="test_dag", task=task) + ti = TaskInstance( + id=uuid7(), + task_id=task.task_id, + dag_id=dag.dag_id, + run_id="test_run", + try_number=1, + dag_version_id=uuid7(), + ) + + runtime_ti = RuntimeTaskInstance.model_construct( + **ti.model_dump(exclude_unset=True), task=task, start_date=timezone.utcnow() + ) + + log = mock.MagicMock() + context = runtime_ti.get_template_context() + + with mock.patch( + "airflow.sdk.execution_time.task_runner._validate_task_inlets_and_outlets" + ) as validate_mock: + state, _, error = run(runtime_ti, context, log) + + validate_mock.assert_called_once() + + outlet_event_accessor = listener.outlet_events.pop() + assert outlet_event_accessor.key == test_key + assert outlet_event_accessor.extra == test_extra + + finalize(runtime_ti, state, context, log, error) + + outlet_event_accessor = listener.outlet_events.pop() + assert outlet_event_accessor.key == test_key + assert outlet_event_accessor.extra == test_extra + + assert listener.error == error + @pytest.mark.usefixtures("mock_supervisor_comms") class TestTaskRunnerCallsCallbacks: