From 0c8fafb31df7ef3ef5812d6efb47ca342a3bad3c Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 26 Sep 2024 14:11:08 -0700 Subject: [PATCH] Fix SSR mode flag with `mount_gradio_app` and revert changes to pytests (#9446) * Revert "Fix Python unit tests on `5.0-dev` branch (#9432)" This reverts commit 278645b649fb590e6c9608c568ee0903c735a536. * revert changes to pytest * add changeset * fix --------- Co-authored-by: gradio-pr-bot --- .changeset/stupid-tires-stare.md | 5 ++ gradio/routes.py | 2 +- test/test_routes.py | 138 ++++++++++++++++++++----------- 3 files changed, 97 insertions(+), 48 deletions(-) create mode 100644 .changeset/stupid-tires-stare.md diff --git a/.changeset/stupid-tires-stare.md b/.changeset/stupid-tires-stare.md new file mode 100644 index 0000000000000..e57d563bd8425 --- /dev/null +++ b/.changeset/stupid-tires-stare.md @@ -0,0 +1,5 @@ +--- +"gradio": minor +--- + +feat:Fix SSR mode flag with `mount_gradio_app` and revert changes to pytests diff --git a/gradio/routes.py b/gradio/routes.py index eaf39d6eeefdc..a79088457dce8 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -1567,7 +1567,7 @@ def read_main(): else ( ssr_mode if ssr_mode is not None - else bool(os.getenv("GRADIO_SSR_MODE", "False")) + else os.getenv("GRADIO_SSR_MODE", "False").lower() == "true" ) ) diff --git a/test/test_routes.py b/test/test_routes.py index 27c17371ed0de..d085405e193ec 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -4,7 +4,6 @@ import json import os import pickle -import socket import tempfile import time from contextlib import asynccontextmanager, closing @@ -19,7 +18,6 @@ import pytest import requests import starlette.routing -import uvicorn from fastapi import FastAPI, Request from fastapi.testclient import TestClient from gradio_client import media_data @@ -369,65 +367,97 @@ def test_get_file_created_by_app(self, test_client): assert len(file_response_with_partial_range.text) == 11 def test_mount_gradio_app(self): - @asynccontextmanager - async def empty_lifespan(app: FastAPI): - yield + app = FastAPI() - app = FastAPI(lifespan=empty_lifespan) + demo = gr.Interface( + lambda s: f"Hello from ps, {s}!", "textbox", "textbox" + ).queue() + demo1 = gr.Interface( + lambda s: f"Hello from py, {s}!", "textbox", "textbox" + ).queue() + + app = gr.mount_gradio_app(app, demo, path="/ps") + app = gr.mount_gradio_app(app, demo1, path="/py") + + # Use context manager to trigger start up events + with TestClient(app) as client: + assert client.get("/ps").is_success + assert client.get("/py").is_success - demo1 = gr.Interface(lambda s: f"Hello 1, {s}!", "textbox", "textbox") - demo2 = gr.Interface(lambda s: f"Hello 2, {s}!", "textbox", "textbox") - demo3 = gr.Interface( - lambda s: f"Password-Protected Hello, {s}!", "textbox", "textbox" + def test_mount_gradio_app_with_app_kwargs(self): + app = FastAPI() + demo = gr.Interface(lambda s: f"You said {s}!", "textbox", "textbox").queue() + app = gr.mount_gradio_app( + app, + demo, + path="/echo", + app_kwargs={"docs_url": "/docs-custom"}, ) + # Use context manager to trigger start up events + with TestClient(app) as client: + assert client.get("/echo/docs-custom").is_success - app = gr.mount_gradio_app(app, demo1, path="/demo1") - app = gr.mount_gradio_app(app, demo2, path="/demo2") - app = gr.mount_gradio_app(app, demo3, path="/demo-auth", auth=("a", "b")) + def test_mount_gradio_app_with_auth_and_params(self): + app = FastAPI() + demo = gr.Interface(lambda s: f"You said {s}!", "textbox", "textbox").queue() + app = gr.mount_gradio_app( + app, + demo, + path=f"{API_PREFIX}/echo", + auth=("a", "b"), + root_path=f"{API_PREFIX}/echo", + allowed_paths=["test/test_files/bus.png"], + ) + # Use context manager to trigger start up events + with TestClient(app) as client: + assert client.get(f"{API_PREFIX}/echo/config").status_code == 401 + assert demo.root_path == f"{API_PREFIX}/echo" + assert demo.allowed_paths == ["test/test_files/bus.png"] + assert demo.show_error - def get_free_port(): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) # Bind to any free port - return s.getsockname()[1] # Get the port number + def test_mount_gradio_app_with_lifespan(self): + @asynccontextmanager + async def empty_lifespan(app: FastAPI): + yield - global port, server # noqa: PLW0603 - port = None - server = None + app = FastAPI(lifespan=empty_lifespan) - def run_server(): - global port, server # noqa: PLW0603 + demo = gr.Interface( + lambda s: f"Hello from ps, {s}!", "textbox", "textbox" + ).queue() + demo1 = gr.Interface( + lambda s: f"Hello from py, {s}!", "textbox", "textbox" + ).queue() - port = get_free_port() - config = uvicorn.Config(app, host="127.0.0.1", port=port) - server = uvicorn.Server(config) - server.run() + app = gr.mount_gradio_app(app, demo, path="/ps") + app = gr.mount_gradio_app(app, demo1, path="/py") - server_thread = Thread(target=run_server, daemon=True) - server_thread.start() + # Use context manager to trigger start up events + with TestClient(app) as client: + assert client.get("/ps").is_success + assert client.get("/py").is_success - start_time = time.time() - while server is None: - time.sleep(0.01) - if time.time() - start_time > 3: - raise TimeoutError("Server did not start in time") + def test_mount_gradio_app_with_startup(self): + app = FastAPI() - base_url = f"http://127.0.0.1:{port}" + @app.on_event("startup") + async def empty_startup(): + return - # Test the main routes - assert requests.get(f"{base_url}/demo1").status_code == 200 - assert requests.get(f"{base_url}/demo2").status_code == 200 - assert requests.get(f"{base_url}/demo-non-existent").status_code == 404 + demo = gr.Interface( + lambda s: f"Hello from ps, {s}!", "textbox", "textbox" + ).queue() + demo1 = gr.Interface( + lambda s: f"Hello from py, {s}!", "textbox", "textbox" + ).queue() - # Test auth (TODO: Fix this) - assert ( - requests.get(f"{base_url}/demo-auth").status_code - != 200 # It should be 401, but it's 500 - ) - # requests.post(f"{base_url}/demo-auth/login", data={"username": "a", "password": "b"}) - # assert requests.get(f"{base_url}/demo-auth").status_code == 200 + app = gr.mount_gradio_app(app, demo, path="/ps") + app = gr.mount_gradio_app(app, demo1, path="/py") - server.should_exit = True # type: ignore - server_thread.join() + # Use context manager to trigger start up events + with TestClient(app) as client: + assert client.get("/ps").is_success + assert client.get("/py").is_success def test_gradio_app_with_auth_dependency(self): def block_anonymous(request: Request): @@ -442,6 +472,20 @@ def block_anonymous(request: Request): assert not client.get("/", headers={}).is_success assert client.get("/", headers={"user": "abubakar"}).is_success + def test_mount_gradio_app_with_auth_dependency(self): + app = FastAPI() + + def get_user(request: Request): + return request.headers.get("user") + + demo = gr.Interface(lambda s: f"Hello from ps, {s}!", "textbox", "textbox") + + app = gr.mount_gradio_app(app, demo, path="/demo", auth_dependency=get_user) + + with TestClient(app) as client: + assert client.get("/demo", headers={"user": "abubakar"}).is_success + assert not client.get("/demo").is_success + def test_static_file_missing(self, test_client): response = test_client.get(rf"{API_PREFIX}/static/not-here.js") assert response.status_code == 404