Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ContextVar instead of threading.local() #5625

Merged
merged 9 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/chatty-adults-reply.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": patch
---

fix:Use ContextVar instead of threading.local()
10 changes: 7 additions & 3 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@


def in_event_listener():
from gradio import context
from gradio.context import LocalContext

return getattr(context.thread_data, "in_event_listener", False)
return LocalContext.in_event_listener.get()


def updateable(fn):
Expand Down Expand Up @@ -1135,7 +1135,11 @@ async def call_function(
start = time.time()

fn = utils.get_function_with_locals(
block_fn.fn, self, event_id, in_event_listener
fn=block_fn.fn,
blocks=self,
event_id=event_id,
in_event_listener=in_event_listener,
request=request,
)

if iterator is None: # If not a generator function that has already run
Expand Down
9 changes: 7 additions & 2 deletions gradio/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from __future__ import annotations

import threading
from contextvars import ContextVar
from typing import TYPE_CHECKING

if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from gradio.blocks import BlockContext, Blocks
from gradio.routes import Request


class Context:
Expand All @@ -17,4 +18,8 @@ class Context:
hf_token: str | None = None # The token provided when loading private HF repos


thread_data = threading.local()
class LocalContext:
blocks: ContextVar[Blocks | None] = ContextVar("blocks", default=None)
in_event_listener: ContextVar[bool] = ContextVar("in_event_listener", default=False)
event_id: ContextVar[str | None] = ContextVar("event_id", default=None)
request: ContextVar[Request | None] = ContextVar("request", default=None)
13 changes: 7 additions & 6 deletions gradio/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,22 +1098,23 @@ def __init__(self, target: Block | None, _data: Any):


def log_message(message: str, level: Literal["info", "warning"] = "info"):
from gradio import context
from gradio.context import LocalContext

if not hasattr(context.thread_data, "blocks"): # Function called outside of Gradio
blocks = LocalContext.blocks.get()
if blocks is None: # Function called outside of Gradio
if level == "info":
print(message)
elif level == "warning":
warnings.warn(message)
return
if not context.thread_data.blocks.enable_queue:
if not blocks.enable_queue:
warnings.warn(
f"Queueing must be enabled to issue {level.capitalize()}: '{message}'."
)
return
context.thread_data.blocks._queue.log_message(
event_id=context.thread_data.event_id, log=message, level=level
)
event_id = LocalContext.event_id.get()
assert event_id
blocks._queue.log_message(event_id=event_id, log=message, level=level)


@document()
Expand Down
22 changes: 14 additions & 8 deletions gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from gradio.blocks import Block, BlockContext, Blocks
from gradio.components import Component
from gradio.routes import App
from gradio.routes import App, Request

JSON_PATH = os.path.join(os.path.dirname(gradio.__file__), "launches.json")

Expand Down Expand Up @@ -660,19 +660,25 @@ def wrapper(*args, **kwargs):


def get_function_with_locals(
fn: Callable, blocks: Blocks, event_id: str | None, in_event_listener: bool
fn: Callable,
blocks: Blocks,
event_id: str | None,
in_event_listener: bool,
request: Request | None,
):
def before_fn(blocks, event_id):
from gradio.context import thread_data
from gradio.context import LocalContext

thread_data.blocks = blocks
thread_data.in_event_listener = in_event_listener
thread_data.event_id = event_id
LocalContext.blocks.set(blocks)
LocalContext.in_event_listener.set(in_event_listener)
LocalContext.event_id.set(event_id)
LocalContext.request.set(request)

def after_fn():
from gradio.context import thread_data
from gradio.context import LocalContext

thread_data.in_event_listener = False
LocalContext.in_event_listener.set(False)
LocalContext.request.set(None)

return function_wrapper(
fn, before_fn=before_fn, before_args=(blocks, event_id), after_fn=after_fn
Expand Down
44 changes: 44 additions & 0 deletions test/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
import os
import shutil
Expand Down Expand Up @@ -756,3 +757,46 @@ def greet(s):
["Letter c", "info"],
["Too short!", "warning"],
]


@pytest.mark.asyncio
@pytest.mark.parametrize("async_handler", [True, False])
async def test_info_isolation(async_handler: bool):
async def greet_async(name):
await asyncio.sleep(2)
gr.Info(f"Hello {name}")
return name

def greet_sync(name):
time.sleep(2)
gr.Info(f"Hello {name}")
return name

demo = gr.Interface(greet_async if async_handler else greet_sync, "text", "text")
demo.queue(concurrency_count=2).launch(prevent_thread_lock=True)

async def session_interaction(name, delay=0):
await asyncio.sleep(delay)
async with websockets.connect(
f"{demo.local_url.replace('http', 'ws')}queue/join"
) as ws:
log_messages = []
while True:
msg = json.loads(await ws.recv())
if msg["msg"] == "send_data":
await ws.send(json.dumps({"data": [name], "fn_index": 0}))
if msg["msg"] == "send_hash":
await ws.send(json.dumps({"fn_index": 0, "session_hash": name}))
if msg["msg"] == "log":
log_messages.append(msg["log"])
if msg["msg"] == "process_completed":
break
return log_messages

alice_logs, bob_logs = await asyncio.gather(
session_interaction("Alice"),
session_interaction("Bob", delay=1),
)

assert alice_logs == ["Hello Alice"]
assert bob_logs == ["Hello Bob"]
Loading