diff --git a/airflow-core/src/airflow/decorators/condition.py b/airflow-core/src/airflow/decorators/condition.py index e276b9fc7178e..06fd01391f28b 100644 --- a/airflow-core/src/airflow/decorators/condition.py +++ b/airflow-core/src/airflow/decorators/condition.py @@ -25,7 +25,7 @@ if TYPE_CHECKING: from typing_extensions import TypeAlias - from airflow.models.baseoperator import TaskPreExecuteHook + from airflow.sdk.definitions.baseoperator import TaskPreExecuteHook from airflow.sdk.definitions.context import Context BoolConditionFunc: TypeAlias = Callable[[Context], bool] diff --git a/airflow-core/src/airflow/models/baseoperator.py b/airflow-core/src/airflow/models/baseoperator.py index 3ecc6ebc44efb..77935a5991e51 100644 --- a/airflow-core/src/airflow/models/baseoperator.py +++ b/airflow-core/src/airflow/models/baseoperator.py @@ -29,13 +29,7 @@ from collections.abc import Collection, Iterable, Iterator from datetime import datetime, timedelta from functools import singledispatchmethod -from types import FunctionType -from typing import ( - TYPE_CHECKING, - Any, - Callable, - TypeVar, -) +from typing import TYPE_CHECKING, Any import methodtools import pendulum @@ -62,7 +56,6 @@ cross_downstream as cross_downstream, get_merged_defaults as get_merged_defaults, ) -from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.dag import BaseOperator as TaskSDKBaseOperator from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup @@ -73,8 +66,6 @@ from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep from airflow.utils import timezone -from airflow.utils.context import context_get_outlet_events -from airflow.utils.operator_helpers import ExecutionCallableRunner from airflow.utils.operator_resources import Resources from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import DagRunState @@ -86,16 +77,11 @@ from airflow.models.dag import DAG as SchedulerDAG from airflow.models.operator import Operator - from airflow.sdk import BaseOperatorLink + from airflow.sdk import BaseOperatorLink, Context from airflow.sdk.definitions._internal.node import DAGNode from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.triggers.base import StartTriggerArgs -TaskPreExecuteHook = Callable[[Context], None] -TaskPostExecuteHook = Callable[[Context, Any], None] - -T = TypeVar("T", bound=FunctionType) - logger = logging.getLogger("airflow.models.baseoperator.BaseOperator") @@ -338,20 +324,12 @@ def say_hello_world(**context): start_trigger_args: StartTriggerArgs | None = None start_from_trigger: bool = False - def __init__( - self, - pre_execute=None, - post_execute=None, - **kwargs, - ): + def __init__(self, **kwargs): if start_date := kwargs.get("start_date", None): kwargs["start_date"] = timezone.convert_to_utc(start_date) - if end_date := kwargs.get("end_date", None): kwargs["end_date"] = timezone.convert_to_utc(end_date) super().__init__(**kwargs) - self._pre_execute_hook = pre_execute - self._post_execute_hook = post_execute # Defines the operator level extra links operator_extra_links: Collection[BaseOperatorLink] = () @@ -411,7 +389,10 @@ def pre_execute(self, context: Any): """Execute right before self.execute() is called.""" if self._pre_execute_hook is None: return - ExecutionCallableRunner( + from airflow.sdk.execution_time.callback_runner import create_executable_runner + from airflow.sdk.execution_time.context import context_get_outlet_events + + create_executable_runner( self._pre_execute_hook, context_get_outlet_events(context), logger=self.log, @@ -436,7 +417,10 @@ def post_execute(self, context: Any, result: Any = None): """ if self._post_execute_hook is None: return - ExecutionCallableRunner( + from airflow.sdk.execution_time.callback_runner import create_executable_runner + from airflow.sdk.execution_time.context import context_get_outlet_events + + create_executable_runner( self._post_execute_hook, context_get_outlet_events(context), logger=self.log, diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 081d4e0d115cf..66975f57815ea 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -116,7 +116,6 @@ from airflow.utils.helpers import prune_dict, render_template_to_string from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname -from airflow.utils.operator_helpers import ExecutionCallableRunner from airflow.utils.platform import getuser from airflow.utils.retries import run_with_db_retries from airflow.utils.session import NEW_SESSION, create_session, provide_session @@ -640,13 +639,14 @@ def _execute_task(task_instance: TaskInstance, context: Context, task_orig: Oper ) def _execute_callable(context: Context, **execute_callable_kwargs): - from airflow.utils.context import context_get_outlet_events + from airflow.sdk.execution_time.callback_runner import create_executable_runner + from airflow.sdk.execution_time.context import context_get_outlet_events try: # Print a marker for log grouping of details before task execution log.info("::endgroup::") - return ExecutionCallableRunner( + return create_executable_runner( execute_callable, context_get_outlet_events(context), logger=log, diff --git a/airflow-core/src/airflow/utils/context.py b/airflow-core/src/airflow/utils/context.py index 507211492ebf1..e8cfdc499cc8f 100644 --- a/airflow-core/src/airflow/utils/context.py +++ b/airflow-core/src/airflow/utils/context.py @@ -44,7 +44,6 @@ if TYPE_CHECKING: from airflow.sdk.definitions.asset import Asset - from airflow.sdk.types import OutletEventAccessorsProtocol # NOTE: Please keep this in sync with the following: # * Context in task-sdk/src/airflow/sdk/definitions/context.py @@ -176,10 +175,3 @@ def context_copy_partial(source: Context, keys: Container[str]) -> Context: """ new = {k: v for k, v in source.items() if k in keys} return cast(Context, new) - - -def context_get_outlet_events(context: Context) -> OutletEventAccessorsProtocol: - try: - return context["outlet_events"] - except KeyError: - return OutletEventAccessors() diff --git a/airflow-core/src/airflow/utils/operator_helpers.py b/airflow-core/src/airflow/utils/operator_helpers.py index ef7cd35203902..bf340c1ace9bc 100644 --- a/airflow-core/src/airflow/utils/operator_helpers.py +++ b/airflow-core/src/airflow/utils/operator_helpers.py @@ -18,17 +18,9 @@ from __future__ import annotations import inspect -import logging from collections.abc import Collection, Mapping -from typing import TYPE_CHECKING, Any, Callable, Protocol, TypeVar +from typing import Any, Callable, TypeVar -from airflow.typing_compat import ParamSpec -from airflow.utils.types import NOTSET - -if TYPE_CHECKING: - from airflow.sdk.types import OutletEventAccessorsProtocol - -P = ParamSpec("P") R = TypeVar("R") @@ -58,7 +50,6 @@ def determine( args: Collection[Any], kwargs: Mapping[str, Any], ) -> KeywordParameters: - import inspect import itertools signature = inspect.signature(func) @@ -119,65 +110,3 @@ def kwargs_func(*args, **kwargs): return func(*args, **kwargs) return kwargs_func - - -class _ExecutionCallableRunner(Protocol): - @staticmethod - def run(*args, **kwargs): ... - - -def ExecutionCallableRunner( - func: Callable[P, R], - outlet_events: OutletEventAccessorsProtocol, - *, - logger: logging.Logger, -) -> _ExecutionCallableRunner: - """ - Run an execution callable against a task context and given arguments. - - If the callable is a simple function, this simply calls it with the supplied - arguments (including the context). If the callable is a generator function, - the generator is exhausted here, with the yielded values getting fed back - into the task context automatically for execution. - - This convoluted implementation of inner class with closure is so *all* - arguments passed to ``run()`` can be forwarded to the wrapped function. This - is particularly important for the argument "self", which some use cases - need to receive. This is not possible if this is implemented as a normal - class, where "self" needs to point to the ExecutionCallableRunner object. - - The function name violates PEP 8 due to backward compatibility. This was - implemented as a class previously. - - :meta private: - """ - - class _ExecutionCallableRunnerImpl: - @staticmethod - def run(*args: P.args, **kwargs: P.kwargs) -> R: - from airflow.sdk.definitions.asset.metadata import Metadata - - if not inspect.isgeneratorfunction(func): - return func(*args, **kwargs) - - result: Any = NOTSET - - def _run(): - nonlocal result - result = yield from func(*args, **kwargs) - - for metadata in _run(): - if isinstance(metadata, Metadata): - outlet_events[metadata.asset].extra.update(metadata.extra) - - if metadata.alias: - outlet_events[metadata.alias].add(metadata.asset, extra=metadata.extra) - - continue - logger.warning("Ignoring unknown data of %r received from task", type(metadata)) - if logger.isEnabledFor(logging.DEBUG): - logger.debug("Full yielded value: %r", metadata) - - return result - - return _ExecutionCallableRunnerImpl diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index f0d78e5a58b3e..4cfa53f7be009 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -1350,6 +1350,8 @@ def test_no_new_fields_added_to_base_operator(self): assert fields == { "_logger_name": None, "_needs_expansion": None, + "_post_execute_hook": None, + "_pre_execute_hook": None, "_task_display_name": None, "allow_nested_operators": True, "depends_on_past": False, diff --git a/providers/standard/src/airflow/providers/standard/operators/python.py b/providers/standard/src/airflow/providers/standard/operators/python.py index 2ac571eac2599..8523a9d670b98 100644 --- a/providers/standard/src/airflow/providers/standard/operators/python.py +++ b/providers/standard/src/airflow/providers/standard/operators/python.py @@ -66,10 +66,12 @@ from pendulum.datetime import DateTime + from airflow.sdk.execution_time.callback_runner import ExecutionCallableRunner + from airflow.sdk.execution_time.context import OutletEventAccessorsProtocol + try: from airflow.sdk.definitions.context import Context - except ImportError: - # TODO: Remove once provider drops support for Airflow 2 + except ImportError: # TODO: Remove once provider drops support for Airflow 2 from airflow.utils.context import Context _SerializerTypeDef = Literal["pickle", "cloudpickle", "dill"] @@ -190,14 +192,22 @@ def execute(self, context: Context) -> Any: context_merge(context, self.op_kwargs, templates_dict=self.templates_dict) self.op_kwargs = self.determine_kwargs(context) - if AIRFLOW_V_3_0_PLUS: - from airflow.utils.context import context_get_outlet_events + # This needs to be lazy because subclasses may implement execute_callable + # by running a separate process that can't use the eager result. + def __prepare_execution() -> tuple[ExecutionCallableRunner, OutletEventAccessorsProtocol] | None: + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.callback_runner import create_executable_runner + from airflow.sdk.execution_time.context import context_get_outlet_events + + return create_executable_runner, context_get_outlet_events(context) + if AIRFLOW_V_2_10_PLUS: + from airflow.utils.context import context_get_outlet_events # type: ignore + from airflow.utils.operator_helpers import ExecutionCallableRunner # type: ignore - self._asset_events = context_get_outlet_events(context) - elif AIRFLOW_V_2_10_PLUS: - from airflow.utils.context import context_get_outlet_events + return ExecutionCallableRunner, context_get_outlet_events(context) + return None - self._dataset_events = context_get_outlet_events(context) + self.__prepare_execution = __prepare_execution return_value = self.execute_callable() if self.show_return_value_in_logs: @@ -210,19 +220,18 @@ def execute(self, context: Context) -> Any: def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]: return KeywordParameters.determine(self.python_callable, self.op_args, context).unpacking() + __prepare_execution: Callable[[], tuple[ExecutionCallableRunner, OutletEventAccessorsProtocol] | None] + def execute_callable(self) -> Any: """ Call the python callable with the given arguments. :return: the return value of the call. """ - try: - from airflow.utils.operator_helpers import ExecutionCallableRunner - except ImportError: - # Handle Pre Airflow 2.10 case where ExecutionCallableRunner was not available + if (execution_preparation := self.__prepare_execution()) is None: return self.python_callable(*self.op_args, **self.op_kwargs) - asset_events = self._asset_events if AIRFLOW_V_3_0_PLUS else self._dataset_events - runner = ExecutionCallableRunner(self.python_callable, asset_events, logger=self.log) + create_execution_runner, asset_events = execution_preparation + runner = create_execution_runner(self.python_callable, asset_events, logger=self.log) return runner.run(*self.op_args, **self.op_kwargs) diff --git a/task-sdk/src/airflow/sdk/definitions/baseoperator.py b/task-sdk/src/airflow/sdk/definitions/baseoperator.py index 5edfb0c66688e..ce1f16737a622 100644 --- a/task-sdk/src/airflow/sdk/definitions/baseoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/baseoperator.py @@ -87,6 +87,9 @@ from airflow.typing_compat import Self from airflow.utils.operator_resources import Resources + TaskPreExecuteHook = Callable[[Context], None] + TaskPostExecuteHook = Callable[[Context, Any], None] + __all__ = [ "BaseOperator", "chain", @@ -822,8 +825,8 @@ def say_hello_world(**context): on_success_callback: Sequence[TaskStateChangeCallback] = () on_retry_callback: Sequence[TaskStateChangeCallback] = () on_skipped_callback: Sequence[TaskStateChangeCallback] = () - # pre_execute: TaskPreExecuteHook | None = None - # post_execute: TaskPostExecuteHook | None = None + _pre_execute_hook: TaskPreExecuteHook | None = None + _post_execute_hook: TaskPostExecuteHook | None = None trigger_rule: TriggerRule = DEFAULT_TRIGGER_RULE resources: dict[str, Any] | None = None run_as_user: str | None = None @@ -981,8 +984,8 @@ def __init__( on_success_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None, on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, on_skipped_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None, - # pre_execute: TaskPreExecuteHook | None = None, - # post_execute: TaskPostExecuteHook | None = None, + pre_execute: TaskPreExecuteHook | None = None, + post_execute: TaskPostExecuteHook | None = None, trigger_rule: str = DEFAULT_TRIGGER_RULE, resources: dict[str, Any] | None = None, run_as_user: str | None = None, @@ -1053,14 +1056,13 @@ def __init__( ) self.execution_timeout = execution_timeout - # TODO: self.on_execute_callback = _collect_callbacks(on_execute_callback) self.on_failure_callback = _collect_callbacks(on_failure_callback) self.on_success_callback = _collect_callbacks(on_success_callback) self.on_retry_callback = _collect_callbacks(on_retry_callback) self.on_skipped_callback = _collect_callbacks(on_skipped_callback) - # self._pre_execute_hook = pre_execute - # self._post_execute_hook = post_execute + self._pre_execute_hook = pre_execute + self._post_execute_hook = post_execute if start_date: self.start_date = timezone.convert_to_utc(start_date) diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_runner.py b/task-sdk/src/airflow/sdk/execution_time/callback_runner.py new file mode 100644 index 0000000000000..31f701074cf3d --- /dev/null +++ b/task-sdk/src/airflow/sdk/execution_time/callback_runner.py @@ -0,0 +1,110 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import inspect +import logging +from typing import TYPE_CHECKING, Callable, Generic, Protocol, TypeVar, cast + +from typing_extensions import ParamSpec + +if TYPE_CHECKING: + from structlog.typing import FilteringBoundLogger as Logger + + from airflow.sdk.types import OutletEventAccessorsProtocol + +P = ParamSpec("P") +R = TypeVar("R") + + +class _ExecutionCallableRunner(Generic[P, R]): + @staticmethod + def run(*args: P.args, **kwargs: P.kwargs) -> R: ... # type: ignore[empty-body] + + +class ExecutionCallableRunner(Protocol): + def __call__( + self, + func: Callable[P, R], + outlet_events: OutletEventAccessorsProtocol, + *, + logger: logging.Logger | Logger, + ) -> _ExecutionCallableRunner[P, R]: ... + + +def create_executable_runner( + func: Callable[P, R], + outlet_events: OutletEventAccessorsProtocol, + *, + logger: logging.Logger | Logger, +) -> _ExecutionCallableRunner[P, R]: + """ + Run an execution callable against a task context and given arguments. + + If the callable is a simple function, this simply calls it with the supplied + arguments (including the context). If the callable is a generator function, + the generator is exhausted here, with the yielded values getting fed back + into the task context automatically for execution. + + This convoluted implementation of inner class with closure is so *all* + arguments passed to ``run()`` can be forwarded to the wrapped function. This + is particularly important for the argument "self", which some use cases + need to receive. This is not possible if this is implemented as a normal + class, where "self" needs to point to the runner object, not the object + bounded to the inner callable. + + :meta private: + """ + + class _ExecutionCallableRunnerImpl(_ExecutionCallableRunner): + @staticmethod + def run(*args: P.args, **kwargs: P.kwargs) -> R: + from airflow.sdk.definitions.asset.metadata import Metadata + + if not inspect.isgeneratorfunction(func): + return func(*args, **kwargs) + + result: R + + if isinstance(logger, logging.Logger): + + def _warn_unknown(metadata): + logger.warning("Ignoring unknown data of %r received from task", type(metadata)) + logger.debug("Full yielded value: %r", metadata) + else: + + def _warn_unknown(metadata): + logger.warning("Ignoring unknown type received from task", type=type(metadata)) + logger.debug("Full yielded value", metadata=metadata) + + def _run(): + nonlocal result + result = yield from func(*args, **kwargs) + + for metadata in _run(): + if isinstance(metadata, Metadata): + outlet_events[metadata.asset].extra.update(metadata.extra) + if metadata.alias: + outlet_events[metadata.alias].add(metadata.asset, extra=metadata.extra) + else: + _warn_unknown(metadata) + + return result # noqa: F821 # Ruff is not smart enough to know this is always set in _run(). + + return cast(_ExecutionCallableRunner[P, R], _ExecutionCallableRunnerImpl) diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index c4d6d9c439ce1..ce14850906d60 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -53,6 +53,7 @@ PrevSuccessfulDagRunResponse, VariableResult, ) + from airflow.sdk.types import OutletEventAccessorsProtocol DEFAULT_FORMAT_PREFIX = "airflow.ctx." @@ -561,3 +562,11 @@ def context_to_airflow_vars(context: Mapping[str, Any], in_env_var_format: bool params[mapping_value] = str(_attr) return params + + +def context_get_outlet_events(context: Context) -> OutletEventAccessorsProtocol: + try: + outlet_events = context["outlet_events"] + except KeyError: + outlet_events = context["outlet_events"] = OutletEventAccessors() + return outlet_events diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 17a3451155e37..d6c22f96ec816 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -24,7 +24,7 @@ import os import sys import time -from collections.abc import Iterable, Iterator, Mapping +from collections.abc import Callable, Iterable, Iterator, Mapping from datetime import datetime, timezone from io import FileIO from itertools import product @@ -53,6 +53,7 @@ from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.param import process_params from airflow.sdk.exceptions import ErrorType +from airflow.sdk.execution_time.callback_runner import create_executable_runner from airflow.sdk.execution_time.comms import ( DagRunStateResult, DeferTask, @@ -79,6 +80,7 @@ MacrosAccessor, OutletEventAccessors, VariableAccessor, + context_get_outlet_events, context_to_airflow_vars, get_previous_dagrun_success, set_current_context, @@ -834,9 +836,10 @@ def _run_task_state_change_callbacks( context: Context, log: Logger, ) -> None: + callback: Callable[[Context], None] for i, callback in enumerate(getattr(task, kind)): try: - callback(context) + create_executable_runner(callback, context_get_outlet_events(context), logger=log).run(context) except Exception: log.exception("Failed to run task callback", kind=kind, index=i, callback=callback) @@ -863,6 +866,11 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) os.environ.update(airflow_context_vars) + outlet_events = context_get_outlet_events(context) + + if (pre_execute_hook := task._pre_execute_hook) is not None: + create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context) + _run_task_state_change_callbacks(task, "on_execute_callback", context, log) if task.execution_timeout: @@ -882,6 +890,10 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): raise else: result = ctx.run(execute, context=context) + + if (post_execute_hook := task._post_execute_hook) is not None: + create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result) + return result