From f6796f77f4adfc42cd1608ca1b8a22cba4432685 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Mon, 15 Nov 2021 07:57:54 -0700 Subject: [PATCH] Preserve contextvars during comm offload (#5486) --- distributed/tests/test_utils.py | 13 +++++++++++++ distributed/utils.py | 7 ++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 0a698a3b3a4..f07975de8e7 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -1,5 +1,6 @@ import array import asyncio +import contextvars import functools import io import os @@ -554,6 +555,18 @@ async def test_offload(): assert (await offload(lambda x, y: x + y, 1, y=2)) == 3 +@pytest.mark.asyncio +async def test_offload_preserves_contextvars(): + var = contextvars.ContextVar("var") + + async def set_var(v: str): + var.set(v) + r = await offload(var.get) + assert r == v + + await asyncio.gather(set_var("foo"), set_var("bar")) + + def test_serialize_for_cli_deprecated(): with pytest.warns(FutureWarning, match="serialize_for_cli is deprecated"): from distributed.utils import serialize_for_cli diff --git a/distributed/utils.py b/distributed/utils.py index 19697804415..1cf51bdacb5 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextvars import functools import importlib import inspect @@ -1326,7 +1327,11 @@ def import_term(name: str): async def offload(fn, *args, **kwargs): loop = asyncio.get_event_loop() - return await loop.run_in_executor(_offload_executor, lambda: fn(*args, **kwargs)) + # Retain context vars while deserializing; see https://bugs.python.org/issue34014 + context = contextvars.copy_context() + return await loop.run_in_executor( + _offload_executor, lambda: context.run(fn, *args, **kwargs) + ) class EmptyContext: