|
9 | 9 | from dataclasses import dataclass as python_dataclass
|
10 | 10 | from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
|
11 | 11 | from typing import TYPE_CHECKING, AsyncGenerator, BinaryIO, List, Optional, Tuple, Union
|
| 12 | +from urllib.parse import urlparse |
12 | 13 |
|
13 | 14 | import fastapi
|
14 | 15 | import httpx
|
|
17 | 18 | from multipart.multipart import parse_options_header
|
18 | 19 | from starlette.datastructures import FormData, Headers, UploadFile
|
19 | 20 | from starlette.formparsers import MultiPartException, MultipartPart
|
| 21 | +from starlette.middleware.base import BaseHTTPMiddleware |
20 | 22 |
|
21 | 23 | from gradio import processing_utils, utils
|
22 | 24 | from gradio.data_classes import PredictBody
|
@@ -583,3 +585,58 @@ def starts_with_protocol(string: str) -> bool:
|
583 | 585 | """
|
584 | 586 | pattern = r"^[a-zA-Z][a-zA-Z0-9+\-.]*://"
|
585 | 587 | return re.match(pattern, string) is not None
|
| 588 | + |
| 589 | + |
| 590 | +def get_hostname(url: str) -> str: |
| 591 | + """ |
| 592 | + Returns the hostname of a given url, or an empty string if the url cannot be parsed. |
| 593 | + Examples: |
| 594 | + get_hostname("https://www.gradio.app") -> "www.gradio.app" |
| 595 | + get_hostname("localhost:7860") -> "localhost" |
| 596 | + get_hostname("127.0.0.1") -> "127.0.0.1" |
| 597 | + """ |
| 598 | + if not url: |
| 599 | + return "" |
| 600 | + if "://" not in url: |
| 601 | + url = "http://" + url |
| 602 | + try: |
| 603 | + return urlparse(url).hostname or "" |
| 604 | + except Exception: |
| 605 | + return "" |
| 606 | + |
| 607 | + |
| 608 | +class CustomCORSMiddleware(BaseHTTPMiddleware): |
| 609 | + async def dispatch(self, request: fastapi.Request, call_next): |
| 610 | + host: str = request.headers.get("host", "") |
| 611 | + origin: str = request.headers.get("origin", "") |
| 612 | + host_name = get_hostname(host) |
| 613 | + origin_name = get_hostname(origin) |
| 614 | + |
| 615 | + # Any of these hosts suggests that the Gradio app is running locally. |
| 616 | + # Note: "null" is a special case that happens if a Gradio app is running |
| 617 | + # as an embedded web component in a local static webpage. |
| 618 | + localhost_aliases = ["localhost", "127.0.0.1", "0.0.0.0", "null"] |
| 619 | + is_preflight = ( |
| 620 | + request.method == "OPTIONS" |
| 621 | + and "access-control-request-method" in request.headers |
| 622 | + ) |
| 623 | + |
| 624 | + if host_name in localhost_aliases and origin_name not in localhost_aliases: |
| 625 | + allow_origin_header = None |
| 626 | + else: |
| 627 | + allow_origin_header = origin |
| 628 | + |
| 629 | + if is_preflight: |
| 630 | + response = fastapi.Response() |
| 631 | + else: |
| 632 | + response = await call_next(request) |
| 633 | + |
| 634 | + if allow_origin_header: |
| 635 | + response.headers["Access-Control-Allow-Origin"] = allow_origin_header |
| 636 | + response.headers[ |
| 637 | + "Access-Control-Allow-Methods" |
| 638 | + ] = "GET, POST, PUT, DELETE, OPTIONS" |
| 639 | + response.headers[ |
| 640 | + "Access-Control-Allow-Headers" |
| 641 | + ] = "Origin, Content-Type, Accept" |
| 642 | + return response |
0 commit comments