diff --git a/task-sdk/tests/conftest.py b/task-sdk/tests/conftest.py index 0fa1cf6d50592..b924bc4384c5e 100644 --- a/task-sdk/tests/conftest.py +++ b/task-sdk/tests/conftest.py @@ -175,6 +175,48 @@ def _disable_ol_plugin(): airflow.plugins_manager.plugins = None +@pytest.fixture(autouse=True) +def _cleanup_async_resources(request): + """ + Clean up async resources that can cause Python 3.12 fork warnings. + + Problem: asgiref.sync.sync_to_async (used in _async_get_connection) creates + ThreadPoolExecutors that persist between tests. When supervisor.py calls + os.fork() in subsequent tests, Python 3.12+ warns about forking a + multi-threaded process. + + Solution: Clean up asgiref's ThreadPoolExecutors after async tests to ensure + subsequent tests start with a clean thread environment. + """ + yield + + # Only clean up after async tests to avoid unnecessary overhead + if "asyncio" in request.keywords: + # Clean up asgiref ThreadPoolExecutors that persist between tests + # These are created by sync_to_async() calls in async connection retrieval + try: + from asgiref.sync import SyncToAsync + + # SyncToAsync maintains a class-level executor for performance + # We need to shut it down to prevent multi-threading warnings on fork() + if hasattr(SyncToAsync, "single_thread_executor") and SyncToAsync.single_thread_executor: + if not SyncToAsync.single_thread_executor._shutdown: + SyncToAsync.single_thread_executor.shutdown(wait=True) + SyncToAsync.single_thread_executor = None + + # SyncToAsync also maintains a WeakKeyDictionary of context-specific executors + # Clean these up too to ensure complete thread cleanup + if hasattr(SyncToAsync, "context_to_thread_executor"): + for executor in list(SyncToAsync.context_to_thread_executor.values()): + if hasattr(executor, "shutdown") and not getattr(executor, "_shutdown", True): + executor.shutdown(wait=True) + SyncToAsync.context_to_thread_executor.clear() + + except (ImportError, AttributeError): + # If asgiref structure changes, fail gracefully + pass + + class MakeTIContextCallable(Protocol): def __call__( self,