Skip to content

Commit 84802ee

Browse files
Tighten CORS rules (#7503)
* tighten cors rules * add changeset * cors policy * cors * add changeset * lint * changes * changes * changes * logging * add null * changes * changes * options * options * safe changes * let browser enforce cors * clean * route utils * fix * fix test * fix --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
1 parent b186767 commit 84802ee

File tree

4 files changed

+82
-7
lines changed

4 files changed

+82
-7
lines changed

.changeset/olive-symbols-heal.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"gradio": patch
3+
---
4+
5+
feat:Tighten CORS rules

gradio/route_utils.py

+57
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from dataclasses import dataclass as python_dataclass
1010
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
1111
from typing import TYPE_CHECKING, AsyncGenerator, BinaryIO, List, Optional, Tuple, Union
12+
from urllib.parse import urlparse
1213

1314
import fastapi
1415
import httpx
@@ -17,6 +18,7 @@
1718
from multipart.multipart import parse_options_header
1819
from starlette.datastructures import FormData, Headers, UploadFile
1920
from starlette.formparsers import MultiPartException, MultipartPart
21+
from starlette.middleware.base import BaseHTTPMiddleware
2022

2123
from gradio import processing_utils, utils
2224
from gradio.data_classes import PredictBody
@@ -583,3 +585,58 @@ def starts_with_protocol(string: str) -> bool:
583585
"""
584586
pattern = r"^[a-zA-Z][a-zA-Z0-9+\-.]*://"
585587
return re.match(pattern, string) is not None
588+
589+
590+
def get_hostname(url: str) -> str:
591+
"""
592+
Returns the hostname of a given url, or an empty string if the url cannot be parsed.
593+
Examples:
594+
get_hostname("https://www.gradio.app") -> "www.gradio.app"
595+
get_hostname("localhost:7860") -> "localhost"
596+
get_hostname("127.0.0.1") -> "127.0.0.1"
597+
"""
598+
if not url:
599+
return ""
600+
if "://" not in url:
601+
url = "http://" + url
602+
try:
603+
return urlparse(url).hostname or ""
604+
except Exception:
605+
return ""
606+
607+
608+
class CustomCORSMiddleware(BaseHTTPMiddleware):
609+
async def dispatch(self, request: fastapi.Request, call_next):
610+
host: str = request.headers.get("host", "")
611+
origin: str = request.headers.get("origin", "")
612+
host_name = get_hostname(host)
613+
origin_name = get_hostname(origin)
614+
615+
# Any of these hosts suggests that the Gradio app is running locally.
616+
# Note: "null" is a special case that happens if a Gradio app is running
617+
# as an embedded web component in a local static webpage.
618+
localhost_aliases = ["localhost", "127.0.0.1", "0.0.0.0", "null"]
619+
is_preflight = (
620+
request.method == "OPTIONS"
621+
and "access-control-request-method" in request.headers
622+
)
623+
624+
if host_name in localhost_aliases and origin_name not in localhost_aliases:
625+
allow_origin_header = None
626+
else:
627+
allow_origin_header = origin
628+
629+
if is_preflight:
630+
response = fastapi.Response()
631+
else:
632+
response = await call_next(request)
633+
634+
if allow_origin_header:
635+
response.headers["Access-Control-Allow-Origin"] = allow_origin_header
636+
response.headers[
637+
"Access-Control-Allow-Methods"
638+
] = "GET, POST, PUT, DELETE, OPTIONS"
639+
response.headers[
640+
"Access-Control-Allow-Headers"
641+
] = "Origin, Content-Type, Accept"
642+
return response

gradio/routes.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import markupsafe
3030
import orjson
3131
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, status
32-
from fastapi.middleware.cors import CORSMiddleware
3332
from fastapi.responses import (
3433
FileResponse,
3534
HTMLResponse,
@@ -55,6 +54,7 @@
5554
from gradio.processing_utils import add_root_url
5655
from gradio.queueing import Estimation
5756
from gradio.route_utils import ( # noqa: F401
57+
CustomCORSMiddleware,
5858
FileUploadProgress,
5959
FileUploadProgressNotQueuedError,
6060
FileUploadProgressNotTrackedError,
@@ -196,12 +196,7 @@ def create_app(
196196
app.configure_app(blocks)
197197

198198
if not wasm_utils.IS_WASM:
199-
app.add_middleware(
200-
CORSMiddleware,
201-
allow_origins=["*"],
202-
allow_methods=["*"],
203-
allow_headers=["*"],
204-
)
199+
app.add_middleware(CustomCORSMiddleware)
205200

206201
@app.get("/user")
207202
@app.get("/user/")

test/test_routes.py

+18
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,24 @@ def test_can_get_config_that_includes_non_pickle_able_objects(self):
462462
response = client.get("/config/")
463463
assert response.is_success
464464

465+
def test_cors_restrictions(self):
466+
io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
467+
app, _, _ = io.launch(prevent_thread_lock=True)
468+
client = TestClient(app)
469+
custom_headers = {
470+
"host": "localhost:7860",
471+
"origin": "https://example.com",
472+
}
473+
file_response = client.get("/config", headers=custom_headers)
474+
assert "access-control-allow-origin" not in file_response.headers
475+
custom_headers = {
476+
"host": "localhost:7860",
477+
"origin": "127.0.0.1",
478+
}
479+
file_response = client.get("/config", headers=custom_headers)
480+
assert file_response.headers["access-control-allow-origin"] == "127.0.0.1"
481+
io.close()
482+
465483

466484
class TestApp:
467485
def test_create_app(self):

0 commit comments

Comments
 (0)