diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py index 136978e238101..b4974b00e5293 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py @@ -219,7 +219,7 @@ def test_should_respond_200(self, test_client, session): "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -377,7 +377,7 @@ def test_should_respond_200_with_task_state_in_deferred(self, test_client, sessi "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -441,7 +441,7 @@ def test_should_respond_200_with_task_state_in_removed(self, test_client, sessio "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -497,7 +497,7 @@ def test_should_respond_200_task_instance_with_rendered(self, test_client, sessi "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -617,7 +617,7 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, test_client, se "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -1409,7 +1409,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): False, "/dags/~/dagRuns/~/taskInstances", {"dag_id_pattern": "example_python_operator"}, - 9, # Based on test failure - example_python_operator creates 9 task instances + 14, # Based on test failure - example_python_operator creates 14 task instances 3, id="test dag_id_pattern exact match", ), @@ -1418,7 +1418,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): False, "/dags/~/dagRuns/~/taskInstances", {"dag_id_pattern": "example_%"}, - 17, # Based on test failure - both DAGs together create 17 task instances + 22, # Based on test failure - both DAGs together create 22 task instances 3, id="test dag_id_pattern wildcard prefix", ), @@ -1932,8 +1932,8 @@ def test_should_respond_200_when_task_instance_properties_are_none( [ pytest.param( {"dag_ids": ["example_python_operator", "example_skip_dag"]}, - 17, - 17, + 22, + 22, id="with dag filter", ), ], @@ -2042,7 +2042,7 @@ def test_should_respond_200_for_pagination(self, test_client, session): assert len(response_batch2.json()["task_instances"]) > 0 # Match - ti_count = 9 + ti_count = 10 assert response_batch1.json()["total_entries"] == response_batch2.json()["total_entries"] == ti_count assert (num_entries_batch1 + num_entries_batch2) == ti_count assert response_batch1 != response_batch2 @@ -2081,7 +2081,7 @@ def test_should_respond_200(self, test_client, session): "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -2128,7 +2128,7 @@ def test_should_respond_200_with_different_try_numbers(self, test_client, try_nu "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": 
"default_queue", "queued_when": None, "scheduled_when": None, @@ -2206,7 +2206,7 @@ def test_should_respond_200_with_mapped_task_at_different_try_numbers( "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -2279,7 +2279,7 @@ def test_should_respond_200_with_task_state_in_deferred(self, test_client, sessi "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -2327,7 +2327,7 @@ def test_should_respond_200_with_task_state_in_removed(self, test_client, sessio "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -3244,7 +3244,7 @@ def test_should_respond_200_with_dag_run_id( "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -3616,7 +3616,7 @@ def test_should_respond_200(self, test_client, session): "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -3654,7 +3654,7 @@ def test_should_respond_200(self, test_client, session): "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -3802,7 +3802,7 @@ def test_ti_in_retry_state_not_returned(self, test_client, session): "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -3886,7 +3886,7 @@ def test_mapped_task_should_respond_200(self, test_client, session): "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -3924,7 +3924,7 @@ def test_mapped_task_should_respond_200(self, test_client, session): "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -4170,7 +4170,7 @@ def test_should_call_mocked_api(self, mock_set_ti_state, test_client, session): "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -4444,7 +4444,7 @@ def test_should_raise_422_for_invalid_task_instance_state(self, payload, expecte "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -4578,7 +4578,7 @@ def test_update_mask_set_note_should_respond_200( "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -4639,7 +4639,7 @@ def test_set_note_should_respond_200(self, test_client, session): "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -4718,7 +4718,7 @@ def 
test_set_note_should_respond_200_mapped_task_with_rtif(self, test_client, se "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -4799,7 +4799,7 @@ def test_set_note_should_respond_200_mapped_task_summary_with_rtif(self, test_cl "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -4917,7 +4917,7 @@ def test_should_call_mocked_api(self, mock_set_ti_state, test_client, session): "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, @@ -5203,7 +5203,7 @@ def test_should_raise_422_for_invalid_task_instance_state(self, payload, expecte "pid": 100, "pool": "default_pool", "pool_slots": 1, - "priority_weight": 9, + "priority_weight": 14, "queue": "default_queue", "queued_when": None, "scheduled_when": None, diff --git a/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py b/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py index 6b77db3e4a904..abe3308149dfc 100644 --- a/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py +++ b/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py @@ -17,18 +17,230 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from airflow.providers.common.compat._compat_utils import create_module_getattr +from airflow.providers.common.compat.version_compat import ( + AIRFLOW_V_3_0_PLUS, + AIRFLOW_V_3_1_PLUS, + AIRFLOW_V_3_2_PLUS, +) _IMPORT_MAP: dict[str, str | tuple[str, ...]] = { # Re-export from sdk (which handles Airflow 2.x/3.x fallbacks) + "AsyncExecutionCallableRunner": "airflow.providers.common.compat.sdk", "BaseOperator": "airflow.providers.common.compat.sdk", + "BaseAsyncOperator": "airflow.providers.common.compat.sdk", + "create_async_executable_runner": "airflow.providers.common.compat.sdk", "get_current_context": "airflow.providers.common.compat.sdk", + "is_async_callable": "airflow.providers.common.compat.sdk", # Standard provider items with direct fallbacks "PythonOperator": ("airflow.providers.standard.operators.python", "airflow.operators.python"), "ShortCircuitOperator": ("airflow.providers.standard.operators.python", "airflow.operators.python"), "_SERIALIZERS": ("airflow.providers.standard.operators.python", "airflow.operators.python"), } +if TYPE_CHECKING: + from structlog.typing import FilteringBoundLogger as Logger + + from airflow.sdk.bases.decorator import is_async_callable + from airflow.sdk.bases.operator import BaseAsyncOperator + from airflow.sdk.execution_time.callback_runner import ( + AsyncExecutionCallableRunner, + create_async_executable_runner, + ) + from airflow.sdk.types import OutletEventAccessorsProtocol +elif AIRFLOW_V_3_2_PLUS: + from airflow.sdk.bases.decorator import is_async_callable + from airflow.sdk.bases.operator import BaseAsyncOperator + from airflow.sdk.execution_time.callback_runner import ( + AsyncExecutionCallableRunner, + create_async_executable_runner, + ) +else: + import asyncio + import contextlib + import inspect + import logging + from asyncio import AbstractEventLoop + from collections.abc import AsyncIterator, Awaitable, Callable, Generator + from contextlib import suppress + from functools import partial + from 
typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, cast + + from typing_extensions import ParamSpec + + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import BaseOperator + from airflow.sdk.bases.decorator import _TaskDecorator + from airflow.sdk.definitions.asset.metadata import Metadata + from airflow.sdk.definitions.mappedoperator import OperatorPartial + else: + from airflow.datasets.metadata import Metadata + from airflow.decorators.base import _TaskDecorator + from airflow.models import BaseOperator + from airflow.models.mappedoperator import OperatorPartial + + P = ParamSpec("P") + R = TypeVar("R") + + @contextlib.contextmanager + def event_loop() -> Generator[AbstractEventLoop]: + new_event_loop = False + loop = None + try: + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + raise RuntimeError + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + new_event_loop = True + yield loop + finally: + if new_event_loop and loop is not None: + with contextlib.suppress(AttributeError): + loop.close() + asyncio.set_event_loop(None) + + def unwrap_partial(fn): + while isinstance(fn, partial): + fn = fn.func + return fn + + def unwrap_callable(func): + # Airflow-specific unwrap + if isinstance(func, (_TaskDecorator, OperatorPartial)): + func = getattr(func, "function", getattr(func, "_func", func)) + + # Unwrap functools.partial + func = unwrap_partial(func) + + # Unwrap @functools.wraps chains + with suppress(Exception): + func = inspect.unwrap(func) + + return func + + def is_async_callable(func): + """Detect if a callable (possibly wrapped) is an async function.""" + func = unwrap_callable(func) + + if not callable(func): + return False + + # Direct async function + if inspect.iscoroutinefunction(func): + return True + + # Callable object with async __call__ + if not inspect.isfunction(func): + call = type(func).__call__ # Bandit-safe + with suppress(Exception): + call = inspect.unwrap(call) + if inspect.iscoroutinefunction(call): + return True + + return False + + class _AsyncExecutionCallableRunner(Generic[P, R]): + @staticmethod + async def run(*args: P.args, **kwargs: P.kwargs) -> R: ... # type: ignore[empty-body] + + class AsyncExecutionCallableRunner(Protocol): + def __call__( + self, + func: Callable[P, R], + outlet_events: OutletEventAccessorsProtocol, + *, + logger: logging.Logger | Logger, + ) -> _AsyncExecutionCallableRunner[P, R]: ... + + def create_async_executable_runner( + func: Callable[P, Awaitable[R] | AsyncIterator], + outlet_events: OutletEventAccessorsProtocol, + *, + logger: logging.Logger | Logger, + ) -> _AsyncExecutionCallableRunner[P, R]: + """ + Run an async execution callable against a task context and given arguments. + + If the callable is a simple function, this simply calls it with the supplied + arguments (including the context). If the callable is an async generator function, + the generator is exhausted here, with the yielded values getting fed back + into the task context automatically for execution. + + This convoluted implementation of inner class with closure is so *all* + arguments passed to ``run()`` can be forwarded to the wrapped function. This + is particularly important for the argument "self", which some use cases + need to receive. This is not possible if this is implemented as a normal + class, where "self" needs to point to the runner object, not the object + bound to the inner callable. 
+ + :meta private: + """ + + class _AsyncExecutionCallableRunnerImpl(_AsyncExecutionCallableRunner): + @staticmethod + async def run(*args: P.args, **kwargs: P.kwargs) -> R: + if not inspect.isasyncgenfunction(func): + result = cast("Awaitable[R]", func(*args, **kwargs)) + return await result + + results: list[Any] = [] + + async for result in func(*args, **kwargs): + if isinstance(result, Metadata): + outlet_events[result.asset].extra.update(result.extra) + if result.alias: + outlet_events[result.alias].add(result.asset, extra=result.extra) + + results.append(result) + + return cast("R", results) + + return cast("_AsyncExecutionCallableRunner[P, R]", _AsyncExecutionCallableRunnerImpl) + + class BaseAsyncOperator(BaseOperator): + """ + Base class for async-capable operators. + + As opposed to deferred operators, which are executed on the triggerer, async operators are executed + on the worker. + """ + + @property + def is_async(self) -> bool: + return True + + if not AIRFLOW_V_3_1_PLUS: + + @property + def xcom_push(self) -> bool: + return self.do_xcom_push + + @xcom_push.setter + def xcom_push(self, value: bool): + self.do_xcom_push = value + + async def aexecute(self, context): + """Async version of execute(). Subclasses should implement this.""" + raise NotImplementedError() + + def execute(self, context): + """Run `aexecute()` inside an event loop.""" + with event_loop() as loop: + if self.execution_timeout: + return loop.run_until_complete( + asyncio.wait_for( + self.aexecute(context), + timeout=self.execution_timeout.total_seconds(), + ) + ) + return loop.run_until_complete(self.aexecute(context)) + + __getattr__ = create_module_getattr(import_map=_IMPORT_MAP) __all__ = sorted(_IMPORT_MAP.keys()) diff --git a/providers/common/compat/src/airflow/providers/common/compat/version_compat.py b/providers/common/compat/src/airflow/providers/common/compat/version_compat.py index 4142937bd2a9a..e3fd1e55f146e 100644 --- a/providers/common/compat/src/airflow/providers/common/compat/version_compat.py +++ b/providers/common/compat/src/airflow/providers/common/compat/version_compat.py @@ -34,6 +34,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_0_PLUS: bool = get_base_airflow_version_tuple() >= (3, 0, 0) AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0) +AIRFLOW_V_3_2_PLUS: bool = get_base_airflow_version_tuple() >= (3, 2, 0) # BaseOperator removed from version_compat to avoid circular imports # Import it directly in files that need it instead @@ -41,4 +42,5 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: __all__ = [ "AIRFLOW_V_3_0_PLUS", "AIRFLOW_V_3_1_PLUS", + "AIRFLOW_V_3_2_PLUS", ] diff --git a/providers/standard/docs/operators/python.rst b/providers/standard/docs/operators/python.rst index 2e5e63ea437e2..091412509d841 100644 --- a/providers/standard/docs/operators/python.rst +++ b/providers/standard/docs/operators/python.rst @@ -72,6 +72,34 @@ Pass extra arguments to the ``@task`` decorated function as you would with a nor :start-after: [START howto_operator_python_kwargs] :end-before: [END howto_operator_python_kwargs] +Async Python functions +^^^^^^^^^^^^^^^^^^^^^^ + +From Airflow 3.2 onward, async Python callables are supported out of the box. +This means you no longer need to manage the event loop yourself and can easily invoke async Python code and async +Airflow hooks that are not always available through deferred operators. 
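The compat ``BaseAsyncOperator.execute()`` above drives ``aexecute()`` on a reused-or-fresh event loop and maps ``execution_timeout`` onto ``asyncio.wait_for``. A minimal standalone sketch of that pattern (plain asyncio only; ``aexecute``/``execute`` are illustrative stand-in names, not Airflow imports):

```python
# Sketch only: mirrors the get-or-create loop logic and timeout handling of
# BaseAsyncOperator.execute() in the diff above.
import asyncio
from datetime import timedelta


async def aexecute() -> str:
    await asyncio.sleep(0.1)  # stand-in for the operator's async body
    return "done"


def execute(execution_timeout: timedelta | None = None) -> str:
    try:
        loop = asyncio.get_event_loop()
        if loop.is_closed():
            raise RuntimeError
    except RuntimeError:
        # No usable loop: create a fresh one, as the event_loop() helper does.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    coro = aexecute()
    if execution_timeout:
        # execution_timeout becomes a hard asyncio-level timeout.
        coro = asyncio.wait_for(coro, timeout=execution_timeout.total_seconds())
    return loop.run_until_complete(coro)


print(execute(execution_timeout=timedelta(seconds=1)))  # prints "done"
```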
+As opposed to deferred operators which are executed on the triggerer, async operators are executed on the workers. + +.. tab-set:: + + .. tab-item:: @task + :sync: taskflow + + .. exampleinclude:: /../src/airflow/providers/standard/example_dags/example_python_decorator.py + :language: python + :dedent: 4 + :start-after: [START howto_async_operator_python_kwargs] + :end-before: [END howto_async_operator_python_kwargs] + + .. tab-item:: PythonOperator + :sync: operator + + .. exampleinclude:: /../src/airflow/providers/standard/example_dags/example_python_operator.py + :language: python + :dedent: 4 + :start-after: [START howto_async_operator_python_kwargs] + :end-before: [END howto_async_operator_python_kwargs] + Templating ^^^^^^^^^^ diff --git a/providers/standard/src/airflow/providers/standard/example_dags/example_python_decorator.py b/providers/standard/src/airflow/providers/standard/example_dags/example_python_decorator.py index ac9938d92eac2..578a09f574fe3 100644 --- a/providers/standard/src/airflow/providers/standard/example_dags/example_python_decorator.py +++ b/providers/standard/src/airflow/providers/standard/example_dags/example_python_decorator.py @@ -22,6 +22,7 @@ from __future__ import annotations +import asyncio import logging import sys import time @@ -64,6 +65,7 @@ def log_sql(**kwargs): # [START howto_operator_python_kwargs] # Generate 5 sleeping tasks, sleeping from 0.0 to 0.4 seconds respectively + # Asynchronous callables are natively supported since Airflow 3.2+ @task def my_sleeping_function(random_base): """This is a function that will run within the DAG execution""" @@ -75,6 +77,22 @@ def my_sleeping_function(random_base): run_this >> log_the_sql >> sleeping_task # [END howto_operator_python_kwargs] + # [START howto_async_operator_python_kwargs] + # Generate 5 sleeping tasks, sleeping from 0.0 to 0.4 seconds respectively + # Asynchronous callables are natively supported since Airflow 3.2+ + @task + async def my_async_sleeping_function(random_base): + """This is a function that will run within the DAG execution""" + await asyncio.sleep(random_base) + + for i in range(5): + async_sleeping_task = my_async_sleeping_function.override(task_id=f"async_sleep_for_{i}")( + random_base=i / 10 + ) + + run_this >> log_the_sql >> async_sleeping_task + # [END howto_async_operator_python_kwargs] + # [START howto_operator_python_venv] @task.virtualenv( task_id="virtualenv_python", requirements=["colorama==0.4.0"], system_site_packages=False diff --git a/providers/standard/src/airflow/providers/standard/example_dags/example_python_operator.py b/providers/standard/src/airflow/providers/standard/example_dags/example_python_operator.py index 18aa8f207e3b0..064ac0420253f 100644 --- a/providers/standard/src/airflow/providers/standard/example_dags/example_python_operator.py +++ b/providers/standard/src/airflow/providers/standard/example_dags/example_python_operator.py @@ -22,6 +22,7 @@ from __future__ import annotations +import asyncio import logging import sys import time @@ -76,6 +77,7 @@ def log_sql(**kwargs): # [START howto_operator_python_kwargs] # Generate 5 sleeping tasks, sleeping from 0.0 to 0.4 seconds respectively + # Asynchronous callables are natively supported since Airflow 3.2+ def my_sleeping_function(random_base): """This is a function that will run within the DAG execution""" time.sleep(random_base) @@ -88,6 +90,23 @@ def my_sleeping_function(random_base): run_this >> log_the_sql >> sleeping_task # [END howto_operator_python_kwargs] + # [START 
howto_async_operator_python_kwargs] + # Generate 5 sleeping tasks, sleeping from 0.0 to 0.4 seconds respectively + # Asynchronous callables are natively supported since Airflow 3.2+ + async def my_async_sleeping_function(random_base): + """This is a function that will run within the DAG execution""" + await asyncio.sleep(random_base) + + for i in range(5): + async_sleeping_task = PythonOperator( + task_id=f"async_sleep_for_{i}", + python_callable=my_async_sleeping_function, + op_kwargs={"random_base": i / 10}, + ) + + run_this >> log_the_sql >> async_sleeping_task + # [END howto_async_operator_python_kwargs] + # [START howto_operator_python_venv] def callable_virtualenv(): """ diff --git a/providers/standard/src/airflow/providers/standard/operators/python.py b/providers/standard/src/airflow/providers/standard/operators/python.py index ac8862f29230a..42fa7a063a4c2 100644 --- a/providers/standard/src/airflow/providers/standard/operators/python.py +++ b/providers/standard/src/airflow/providers/standard/operators/python.py @@ -48,13 +48,18 @@ ) from airflow.models.variable import Variable from airflow.providers.common.compat.sdk import AirflowException, AirflowSkipException, context_merge +from airflow.providers.common.compat.standard.operators import ( + AsyncExecutionCallableRunner, + BaseAsyncOperator, + is_async_callable, +) from airflow.providers.standard.hooks.package_index import PackageIndexHook from airflow.providers.standard.utils.python_virtualenv import ( _execute_in_subprocess, prepare_virtualenv, write_python_script, ) -from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator +from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils import hashlib_wrapper from airflow.utils.file import get_unique_dag_module_name from airflow.utils.operator_helpers import KeywordParameters @@ -115,9 +120,9 @@ def from_executable(cls, executable: str) -> _PythonVersionInfo: return cls(*_parse_version_info(result.strip())) -class PythonOperator(BaseOperator): +class PythonOperator(BaseAsyncOperator): """ - Executes a Python callable. + Base class for all Python operators. .. seealso:: For more information on how to use this operator, take a look at the guide: @@ -192,7 +197,14 @@ def __init__( self.template_ext = templates_exts self.show_return_value_in_logs = show_return_value_in_logs - def execute(self, context: Context) -> Any: + @property + def is_async(self) -> bool: + return is_async_callable(self.python_callable) + + def execute(self, context) -> Any: + if self.is_async: + return BaseAsyncOperator.execute(self, context) + context_merge(context, self.op_kwargs, templates_dict=self.templates_dict) self.op_kwargs = self.determine_kwargs(context) @@ -219,6 +231,40 @@ def __prepare_execution() -> tuple[ExecutionCallableRunner, OutletEventAccessors return return_value + async def aexecute(self, context): + """Async version of execute().""" + context_merge(context, self.op_kwargs, templates_dict=self.templates_dict) + self.op_kwargs = self.determine_kwargs(context) + + # This needs to be lazy because subclasses may implement execute_callable + # by running a separate process that can't use the eager result. 
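The ``is_async`` property above delegates to ``is_async_callable``, which must see through ``functools.partial`` layers and ``@functools.wraps`` chains before ``inspect.iscoroutinefunction`` can answer reliably. A self-contained sketch of that detection logic (``detect_async`` is an illustrative name, not the provider's API):

```python
import functools
import inspect


def detect_async(func) -> bool:
    # Peel functools.partial layers; a partial object is not itself a
    # coroutine function even when it wraps one.
    while isinstance(func, functools.partial):
        func = func.func
    if not callable(func):
        return False
    # Follow __wrapped__ chains left by @functools.wraps.
    func = inspect.unwrap(func)
    if inspect.iscoroutinefunction(func):
        return True
    # Callable instances: inspect __call__ on the type, not the instance.
    if not inspect.isfunction(func):
        call = inspect.unwrap(type(func).__call__)
        return inspect.iscoroutinefunction(call)
    return False


async def add(x, y):
    return x + y


print(detect_async(functools.partial(add, 1)))  # True
print(detect_async(print))  # False
```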
+ def __prepare_execution() -> tuple[AsyncExecutionCallableRunner, OutletEventAccessorsProtocol] | None: + from airflow.providers.common.compat.standard.operators import ( + create_async_executable_runner, + ) + + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.context import ( + context_get_outlet_events, + ) + else: + from airflow.utils.context import context_get_outlet_events # type: ignore + + return ( + cast("AsyncExecutionCallableRunner", create_async_executable_runner), + context_get_outlet_events(context), + ) + + self.__prepare_execution = __prepare_execution + + return_value = await self.aexecute_callable() + if self.show_return_value_in_logs: + self.log.info("Done. Returned value was: %s", return_value) + else: + self.log.info("Done. Returned value not shown") + + return return_value + def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]: return KeywordParameters.determine(self.python_callable, self.op_args, context).unpacking() @@ -236,6 +282,18 @@ def execute_callable(self) -> Any: runner = create_execution_runner(self.python_callable, asset_events, logger=self.log) return runner.run(*self.op_args, **self.op_kwargs) + async def aexecute_callable(self) -> Any: + """ + Call the python callable with the given arguments. + + :return: the return value of the call. + """ + if (execution_preparation := self.__prepare_execution()) is None: + return await self.python_callable(*self.op_args, **self.op_kwargs) + create_execution_runner, asset_events = execution_preparation + runner = create_execution_runner(self.python_callable, asset_events, logger=self.log) + return await runner.run(*self.op_args, **self.op_kwargs) + class BranchPythonOperator(BaseBranchOperator, PythonOperator): """ diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py b/providers/standard/tests/unit/standard/decorators/test_python.py index 32b2fc2615d0a..d2aa38ec54434 100644 --- a/providers/standard/tests/unit/standard/decorators/test_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_python.py @@ -37,8 +37,16 @@ from unit.standard.operators.test_python import BasePythonTest if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import DAG, BaseOperator, TaskGroup, XComArg, setup, task as task_decorator, teardown - from airflow.sdk.bases.decorator import DecoratedMappedOperator + from airflow.sdk import ( + DAG, + BaseOperator, + TaskGroup, + XComArg, + setup, + task as task_decorator, + teardown, + ) + from airflow.sdk.bases.decorator import DecoratedMappedOperator, _TaskDecorator from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput else: from airflow.decorators import ( # type: ignore[attr-defined,no-redef] @@ -46,7 +54,7 @@ task as task_decorator, teardown, ) - from airflow.decorators.base import DecoratedMappedOperator # type: ignore[no-redef] + from airflow.decorators.base import DecoratedMappedOperator, _TaskDecorator # type: ignore[no-redef] from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef] from airflow.models.dag import DAG # type: ignore[assignment,no-redef] from airflow.models.expandinput import DictOfListsExpandInput # type: ignore[attr-defined,no-redef] @@ -658,9 +666,9 @@ def hello(): hello.override(pool="my_pool", priority_weight=i)() weights = [] - for task in self.dag_non_serialized.tasks: - assert task.pool == "my_pool" - weights.append(task.priority_weight) + for _task in self.dag_non_serialized.tasks: + assert _task.pool == "my_pool" + weights.append(_task.priority_weight) assert 
weights == [0, 1, 2] def test_python_callable_args_work_as_well_as_baseoperator_args(self, dag_maker): @@ -1142,3 +1150,19 @@ def my_teardown(): my_teardown() assert work_task.operator.trigger_rule == TriggerRule.ONE_SUCCESS assert setup_task.operator.trigger_rule == TriggerRule.ONE_SUCCESS + + +async def async_fn(): + return 42 + + +def test_python_task(): + from airflow.providers.standard.decorators.python import _PythonDecoratedOperator, python_task + + decorator = python_task(async_fn) + + assert isinstance(decorator, _TaskDecorator) + assert decorator.function == async_fn + assert decorator.operator_class == _PythonDecoratedOperator + assert not decorator.multiple_outputs + assert decorator.kwargs == {"task_id": "async_fn"} diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py index a59c33b29dcf7..e9f9babab8a19 100644 --- a/providers/standard/tests/unit/standard/operators/test_python.py +++ b/providers/standard/tests/unit/standard/operators/test_python.py @@ -17,7 +17,9 @@ # under the License. from __future__ import annotations +import asyncio import copy +import functools import logging import os import pickle @@ -43,7 +45,8 @@ from airflow.exceptions import AirflowProviderDeprecationWarning, DeserializingResultError from airflow.models.connection import Connection from airflow.models.taskinstance import TaskInstance, clear_task_instances -from airflow.providers.common.compat.sdk import AirflowException +from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, task +from airflow.providers.common.compat.standard.operators import is_async_callable from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import ( BranchExternalPythonOperator, @@ -73,11 +76,9 @@ ) if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import BaseOperator from airflow.sdk.execution_time.context import set_current_context from airflow.serialization.serialized_objects import LazyDeserializedDAG else: - from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef] from airflow.models.taskinstance import set_current_context # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: @@ -2465,6 +2466,18 @@ def test_short_circuit_with_teardowns_debug_level(self, dag_maker, level, clear_ assert set(actual_skipped) == {op3} +class TestPythonAsyncOperator(TestPythonOperator): + def test_run_async_task(self, caplog): + caplog.set_level(logging.INFO, logger=LOGGER_NAME) + + async def say_hello(name: str) -> str: + await asyncio.sleep(1) + return f"Hello {name}!" + + self.run_as_task(say_hello, op_kwargs={"name": "world"}, show_return_value_in_logs=True) + assert "Done. Returned value was: Hello world!" 
in caplog.messages + + @pytest.mark.parametrize( ("text_input", "expected_tuple"), [ @@ -2521,3 +2534,141 @@ def test_python_version_info(mocker): assert result.releaselevel == sys.version_info.releaselevel assert result.serial == sys.version_info.serial assert list(result) == list(sys.version_info) + + +def simple_decorator(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + return wrapper + + +def decorator_without_wraps(fn): + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + return wrapper + + +async def async_fn(): + return 42 + + +def sync_fn(): + return 42 + + +@simple_decorator +async def wrapped_async_fn(): + return 42 + + +@simple_decorator +def wrapped_sync_fn(): + return 42 + + +@decorator_without_wraps +async def wrapped_async_fn_no_wraps(): + return 42 + + +@simple_decorator +@simple_decorator +async def multi_wrapped_async_fn(): + return 42 + + +async def async_with_args(x, y): + return x + y + + +def sync_with_args(x, y): + return x + y + + +class AsyncCallable: + async def __call__(self): + return 42 + + +class SyncCallable: + def __call__(self): + return 42 + + +class WrappedAsyncCallable: + @simple_decorator + async def __call__(self): + return 42 + + +class TestAsyncCallable: + def test_plain_async_function(self): + assert is_async_callable(async_fn) + + def test_plain_sync_function(self): + assert not is_async_callable(sync_fn) + + def test_wrapped_async_function_with_wraps(self): + assert is_async_callable(wrapped_async_fn) + + def test_wrapped_sync_function_with_wraps(self): + assert not is_async_callable(wrapped_sync_fn) + + def test_wrapped_async_function_without_wraps(self): + """ + Without functools.wraps, inspect.unwrap cannot recover the coroutine. + This documents expected behavior. + """ + assert not is_async_callable(wrapped_async_fn_no_wraps) + + def test_multi_wrapped_async_function(self): + assert is_async_callable(multi_wrapped_async_fn) + + def test_partial_async_function(self): + fn = functools.partial(async_with_args, 1) + assert is_async_callable(fn) + + def test_partial_sync_function(self): + fn = functools.partial(sync_with_args, 1) + assert not is_async_callable(fn) + + def test_nested_partial_async_function(self): + fn = functools.partial( + functools.partial(async_with_args, 1), + 2, + ) + assert is_async_callable(fn) + + def test_async_callable_class(self): + assert is_async_callable(AsyncCallable()) + + def test_sync_callable_class(self): + assert not is_async_callable(SyncCallable()) + + def test_wrapped_async_callable_class(self): + assert is_async_callable(WrappedAsyncCallable()) + + def test_partial_callable_class(self): + fn = functools.partial(AsyncCallable()) + assert is_async_callable(fn) + + @pytest.mark.parametrize("value", [None, 42, "string", object()]) + def test_non_callable(self, value): + assert not is_async_callable(value) + + def test_task_decorator_async_function(self): + @task + async def async_task_fn(): + return 42 + + assert is_async_callable(async_task_fn) + + def test_task_decorator_sync_function(self): + @task + def sync_task_fn(): + return 42 + + assert not is_async_callable(sync_task_fn) diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst index 9c4d2b880b406..6e573e67ecc44 100644 --- a/task-sdk/docs/api.rst +++ b/task-sdk/docs/api.rst @@ -62,6 +62,8 @@ Task Decorators: Bases ----- +.. autoapiclass:: airflow.sdk.BaseAsyncOperator + .. autoapiclass:: airflow.sdk.BaseOperator .. 
autoapiclass:: airflow.sdk.BaseSensorOperator @@ -176,7 +178,7 @@ Everything else .. autoapimodule:: airflow.sdk :members: :special-members: __version__ - :exclude-members: BaseOperator, DAG, dag, asset, Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher, TaskGroup, XComArg, get_current_context, get_parsing_context + :exclude-members: BaseAsyncOperator, BaseOperator, DAG, dag, asset, Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher, TaskGroup, XComArg, get_current_context, get_parsing_context :undoc-members: :imported-members: :no-index: diff --git a/task-sdk/docs/index.rst b/task-sdk/docs/index.rst index 819f637676b44..f3258ea824319 100644 --- a/task-sdk/docs/index.rst +++ b/task-sdk/docs/index.rst @@ -78,6 +78,7 @@ Why use ``airflow.sdk``? **Classes** - :class:`airflow.sdk.Asset` +- :class:`airflow.sdk.BaseAsyncOperator` - :class:`airflow.sdk.BaseHook` - :class:`airflow.sdk.BaseNotifier` - :class:`airflow.sdk.BaseOperator` diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py index 4dbda282d086e..034a637943072 100644 --- a/task-sdk/src/airflow/sdk/__init__.py +++ b/task-sdk/src/airflow/sdk/__init__.py @@ -26,6 +26,7 @@ "AssetAny", "AssetOrTimeSchedule", "AssetWatcher", + "BaseAsyncOperator", "BaseHook", "BaseNotifier", "BaseOperator", @@ -76,7 +77,13 @@ from airflow.sdk.api.datamodels._generated import DagRunState, TaskInstanceState, TriggerRule, WeightRule from airflow.sdk.bases.hook import BaseHook from airflow.sdk.bases.notifier import BaseNotifier - from airflow.sdk.bases.operator import BaseOperator, chain, chain_linear, cross_downstream + from airflow.sdk.bases.operator import ( + BaseAsyncOperator, + BaseOperator, + chain, + chain_linear, + cross_downstream, + ) from airflow.sdk.bases.operatorlink import BaseOperatorLink from airflow.sdk.bases.sensor import BaseSensorOperator, PokeReturnValue from airflow.sdk.configuration import AirflowSDKConfigParser @@ -117,6 +124,7 @@ "AssetAny": ".definitions.asset", "AssetOrTimeSchedule": ".definitions.timetables.assets", "AssetWatcher": ".definitions.asset", + "BaseAsyncOperator": ".bases.operator", "BaseHook": ".bases.hook", "BaseNotifier": ".bases.notifier", "BaseOperator": ".bases.operator", diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi index eede7ff806a67..b035f49226c69 100644 --- a/task-sdk/src/airflow/sdk/__init__.pyi +++ b/task-sdk/src/airflow/sdk/__init__.pyi @@ -24,6 +24,7 @@ from airflow.sdk.api.datamodels._generated import ( from airflow.sdk.bases.hook import BaseHook as BaseHook from airflow.sdk.bases.notifier import BaseNotifier as BaseNotifier from airflow.sdk.bases.operator import ( + BaseAsyncOperator as BaseAsyncOperator, BaseOperator as BaseOperator, chain as chain, chain_linear as chain_linear, @@ -83,6 +84,7 @@ __all__ = [ "AssetAny", "AssetOrTimeSchedule", "AssetWatcher", + "BaseAsyncOperator", "BaseHook", "BaseNotifier", "BaseOperator", diff --git a/task-sdk/src/airflow/sdk/bases/decorator.py b/task-sdk/src/airflow/sdk/bases/decorator.py index bde898c169681..96768bc45051b 100644 --- a/task-sdk/src/airflow/sdk/bases/decorator.py +++ b/task-sdk/src/airflow/sdk/bases/decorator.py @@ -22,7 +22,8 @@ import textwrap import warnings from collections.abc import Callable, Collection, Iterator, Mapping, Sequence -from functools import cached_property, update_wrapper +from contextlib import suppress +from functools import cached_property, partial, update_wrapper from typing import TYPE_CHECKING, Any, ClassVar, Generic, ParamSpec, Protocol, 
TypeVar, cast, overload import attr @@ -149,6 +150,52 @@ def _find_id_suffixes(dag: DAG) -> Iterator[int]: return f"{core}__{max(_find_id_suffixes(dag)) + 1}" +def unwrap_partial(fn): + while isinstance(fn, partial): + fn = fn.func + return fn + + +def unwrap_callable(func): + from airflow.sdk.bases.decorator import _TaskDecorator + from airflow.sdk.definitions.mappedoperator import OperatorPartial + + # Airflow-specific unwrap + if isinstance(func, (_TaskDecorator, OperatorPartial)): + func = getattr(func, "function", getattr(func, "_func", func)) + + # Unwrap functools.partial + func = unwrap_partial(func) + + # Unwrap @functools.wraps chains + with suppress(Exception): + func = inspect.unwrap(func) + + return func + + +def is_async_callable(func): + """Detect if a callable (possibly wrapped) is an async function.""" + func = unwrap_callable(func) + + if not callable(func): + return False + + # Direct async function + if inspect.iscoroutinefunction(func): + return True + + # Callable object with async __call__ + if not inspect.isfunction(func): + call = type(func).__call__ # Bandit-safe + with suppress(Exception): + call = inspect.unwrap(call) + if inspect.iscoroutinefunction(call): + return True + + return False + + class DecoratedOperator(BaseOperator): """ Wraps a Python callable and captures args/kwargs when called for execution. @@ -243,6 +290,10 @@ def __init__( self.op_kwargs = op_kwargs super().__init__(task_id=task_id, **kwargs_to_upstream, **kwargs) + @property + def is_async(self) -> bool: + return is_async_callable(self.python_callable) + def execute(self, context: Context): # todo make this more generic (move to prepare_lineage) so it deals with non taskflow operators # as well diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py index 0c97df00ef0c5..5f07a1883623b 100644 --- a/task-sdk/src/airflow/sdk/bases/operator.py +++ b/task-sdk/src/airflow/sdk/bases/operator.py @@ -18,13 +18,15 @@ from __future__ import annotations import abc +import asyncio import collections.abc import contextlib import copy import inspect import sys import warnings -from collections.abc import Callable, Collection, Iterable, Mapping, Sequence +from asyncio import AbstractEventLoop +from collections.abc import Callable, Collection, Generator, Iterable, Mapping, Sequence from contextvars import ContextVar from dataclasses import dataclass, field from datetime import datetime, timedelta @@ -101,6 +103,7 @@ def db_safe_priority(priority_weight: int) -> int: TaskPostExecuteHook = Callable[[Context, Any], None] __all__ = [ + "BaseAsyncOperator", "BaseOperator", "chain", "chain_linear", @@ -196,6 +199,27 @@ def coerce_resources(resources: dict[str, Any] | None) -> Resources | None: return Resources(**resources) +@contextlib.contextmanager +def event_loop() -> Generator[AbstractEventLoop]: + new_event_loop = False + loop = None + try: + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + raise RuntimeError + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + new_event_loop = True + yield loop + finally: + if new_event_loop and loop is not None: + with contextlib.suppress(AttributeError): + loop.close() + asyncio.set_event_loop(None) + + class _PartialDescriptor: """A descriptor that guards against ``.partial`` being called on Task objects.""" @@ -1670,6 +1694,35 @@ def has_on_skipped_callback(self) -> bool: return bool(self.on_skipped_callback) +class BaseAsyncOperator(BaseOperator): + """ + Base class for 
async-capable operators. + + As opposed to deferred operators, which are executed on the triggerer, async operators are executed + on the worker. + """ + + @property + def is_async(self) -> bool: + return True + + async def aexecute(self, context): + """Async version of execute(). Subclasses should implement this.""" + raise NotImplementedError() + + def execute(self, context): + """Run `aexecute()` inside an event loop.""" + with event_loop() as loop: + if self.execution_timeout: + return loop.run_until_complete( + asyncio.wait_for( + self.aexecute(context), + timeout=self.execution_timeout.total_seconds(), + ) + ) + return loop.run_until_complete(self.aexecute(context)) + + def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None: r""" Given a number of tasks, builds a dependency chain. diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py index 6c99a72b22080..e7e5ebe8b9ac4 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py @@ -137,6 +137,10 @@ class AbstractOperator(Templater, DAGNode): ) ) + @property + def is_async(self) -> bool: + return False + @property def task_type(self) -> str: raise NotImplementedError() diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_runner.py b/task-sdk/src/airflow/sdk/execution_time/callback_runner.py index 316c3d38e99b8..322e4bc97808a 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_runner.py @@ -20,8 +20,8 @@ import inspect import logging -from collections.abc import Callable -from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast +from collections.abc import AsyncIterator, Awaitable, Callable +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, cast from typing_extensions import ParamSpec @@ -39,6 +39,11 @@ class _ExecutionCallableRunner(Generic[P, R]): def run(*args: P.args, **kwargs: P.kwargs) -> R: ... # type: ignore[empty-body] +class _AsyncExecutionCallableRunner(Generic[P, R]): + @staticmethod + async def run(*args: P.args, **kwargs: P.kwargs) -> R: ... # type: ignore[empty-body] + + class ExecutionCallableRunner(Protocol): def __call__( self, @@ -49,6 +54,16 @@ def __call__( ) -> _ExecutionCallableRunner[P, R]: ... +class AsyncExecutionCallableRunner(Protocol): + def __call__( + self, + func: Callable[P, R], + outlet_events: OutletEventAccessorsProtocol, + *, + logger: logging.Logger | Logger, + ) -> _AsyncExecutionCallableRunner[P, R]: ... + + def create_executable_runner( func: Callable[P, R], outlet_events: OutletEventAccessorsProtocol, *, logger: logging.Logger | Logger, ) -> _ExecutionCallableRunner[P, R]: @@ -109,3 +124,51 @@ def _run(): return result # noqa: F821 # Ruff is not smart enough to know this is always set in _run(). return cast("_ExecutionCallableRunner[P, R]", _ExecutionCallableRunnerImpl) + + +def create_async_executable_runner( + func: Callable[P, Awaitable[R] | AsyncIterator], + outlet_events: OutletEventAccessorsProtocol, + *, + logger: logging.Logger | Logger, +) -> _AsyncExecutionCallableRunner[P, R]: + """ + Run an async execution callable against a task context and given arguments. + + If the callable is a simple function, this simply calls it with the supplied + arguments (including the context). If the callable is an async generator function, + the generator is exhausted here, with the yielded values getting fed back + into the task context automatically for execution. 
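A toy sketch of the exhaustion behaviour this docstring describes (plain asyncio, no Airflow types; the real runner routes yielded ``Metadata`` into outlet events instead of just collecting values):

```python
import asyncio
import inspect


async def run_callable(func, *args):
    # Plain coroutine functions are simply awaited.
    if not inspect.isasyncgenfunction(func):
        return await func(*args)
    # Async generators are exhausted; every yielded value is intercepted.
    results = []
    async for item in func(*args):
        results.append(item)  # the real runner inspects Metadata here
    return results


async def produce():
    yield "first"
    yield "second"


print(asyncio.run(run_callable(produce)))  # ['first', 'second']
```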
+ + This convoluted implementation of inner class with closure is so *all* + arguments passed to ``run()`` can be forwarded to the wrapped function. This + is particularly important for the argument "self", which some use cases + need to receive. This is not possible if this is implemented as a normal + class, where "self" needs to point to the runner object, not the object + bound to the inner callable. + + :meta private: + """ + + class _AsyncExecutionCallableRunnerImpl(_AsyncExecutionCallableRunner): + @staticmethod + async def run(*args: P.args, **kwargs: P.kwargs) -> R: + from airflow.sdk.definitions.asset.metadata import Metadata + + if not inspect.isasyncgenfunction(func): + result = cast("Awaitable[R]", func(*args, **kwargs)) + return await result + + results: list[Any] = [] + + async for result in func(*args, **kwargs): + if isinstance(result, Metadata): + outlet_events[result.asset].extra.update(result.extra) + if result.alias: + outlet_events[result.alias].add(result.asset, extra=result.extra) + + results.append(result) + + return cast("R", results) + + return cast("_AsyncExecutionCallableRunner[P, R]", _AsyncExecutionCallableRunnerImpl) diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 52a96d0b665ea..15755e640d97e 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -48,7 +48,9 @@ from __future__ import annotations +import asyncio import itertools +import threading from collections.abc import Iterator from datetime import datetime from functools import cached_property @@ -185,31 +187,69 @@ class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]): err_decoder: TypeAdapter[ErrorResponse] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False) + # Threading lock for sync operations + _thread_lock: threading.Lock = attrs.field(factory=threading.Lock, repr=False) + # Async lock for async operations + _async_lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False) + def send(self, msg: SendMsgType) -> ReceiveMsgType | None: """Send a request to the parent and block until the response is received.""" frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) frame_bytes = frame.as_bytes() - self.socket.sendall(frame_bytes) - if isinstance(msg, ResendLoggingFD): - if recv_fds is None: - return None - # We need special handling here! The server can't send us the fd number, as the number on the - # supervisor will be different to in this process, so we have to mutate the message ourselves here. - frame, fds = self._read_frame(maxfds=1) - resp = self._from_frame(frame) - if TYPE_CHECKING: - assert isinstance(resp, SentFDs) - resp.fds = fds - # Since we know this is an expliclt SendFDs, and since this class is generic SendFDs might not - # always be in the return type union - return resp # type: ignore[return-value] + # We must make sure sockets aren't intermixed between sync and async calls, + # thus we need a dual locking mechanism to ensure that. + with self._thread_lock: + self.socket.sendall(frame_bytes) + if isinstance(msg, ResendLoggingFD): + if recv_fds is None: + return None + # We need special handling here! The server can't send us the fd number, as the number on the + # supervisor will be different to in this process, so we have to mutate the message ourselves here. 
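The "inner class with closure" shape used by ``_AsyncExecutionCallableRunnerImpl`` above can be shown in isolation: ``run()`` is a staticmethod, so it consumes no ``self`` of its own and forwards every positional argument, including a caller-supplied ``self``, to the captured callable (toy names, sketch only):

```python
import asyncio


def make_runner(func):
    # The callable lives in a closure; run() is a staticmethod, so it consumes
    # no "self" of its own and forwards every argument untouched.
    class _Runner:
        @staticmethod
        async def run(*args, **kwargs):
            return await func(*args, **kwargs)

    return _Runner


class Greeter:
    async def greet(self, name: str) -> str:
        return f"hello {name}"


runner = make_runner(Greeter.greet)
# The Greeter instance is forwarded as the first positional argument and binds
# to the unbound method's "self" parameter.
print(asyncio.run(runner.run(Greeter(), "world")))  # hello world
```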
+ frame, fds = self._read_frame(maxfds=1) + resp = self._from_frame(frame) + if TYPE_CHECKING: + assert isinstance(resp, SentFDs) + resp.fds = fds + # Since we know this is an explicit SendFDs, and since this class is generic, SendFDs might not + # always be in the return type union + return resp # type: ignore[return-value] return self._get_response() async def asend(self, msg: SendMsgType) -> ReceiveMsgType | None: - """Send a request to the parent without blocking.""" - raise NotImplementedError + """ + Send a request to the parent without blocking. + + Uses an async lock for coroutine safety and a thread lock for socket safety. + """ + frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) + frame_bytes = frame.as_bytes() + + async with self._async_lock: + # Acquire the threading lock without blocking the event loop + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self._thread_lock.acquire) + try: + # Async write to socket + await loop.sock_sendall(self.socket, frame_bytes) + + if isinstance(msg, ResendLoggingFD): + if recv_fds is None: + return None + # Blocking read in a thread + frame, fds = await asyncio.to_thread(self._read_frame, maxfds=1) + resp = self._from_frame(frame) + if TYPE_CHECKING: + assert isinstance(resp, SentFDs) + resp.fds = fds + return resp # type: ignore[return-value] + + # Normal blocking read in a thread + frame = await asyncio.to_thread(self._read_frame) + return self._from_frame(frame) + finally: + self._thread_lock.release() @overload def _read_frame(self, maxfds: None = None) -> _ResponseFrame: ...
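The dual-locking scheme in ``asend()`` reduces to this standalone pattern: an ``asyncio.Lock`` serialises coroutines while the ``threading.Lock`` shared with the synchronous ``send()`` path is acquired via ``run_in_executor`` so the event loop never blocks (a sleep stands in for the socket write):

```python
import asyncio
import threading

thread_lock = threading.Lock()  # shared with the synchronous send() path


async def asend(payload: bytes, async_lock: asyncio.Lock) -> None:
    async with async_lock:  # serialise concurrent coroutines
        loop = asyncio.get_running_loop()
        # Acquire the blocking lock in a worker thread so the loop stays free.
        await loop.run_in_executor(None, thread_lock.acquire)
        try:
            await asyncio.sleep(0)  # stand-in for loop.sock_sendall(sock, payload)
        finally:
            thread_lock.release()


async def main() -> None:
    async_lock = asyncio.Lock()
    await asyncio.gather(asend(b"frame-1", async_lock), asend(b"frame-2", async_lock))


asyncio.run(main())
```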