|
17 | 17 |
|
18 | 18 |
|
19 | 19 | import re |
| 20 | +import threading |
20 | 21 | import time |
21 | 22 | from dataclasses import dataclass |
22 | 23 | from typing import Any, Optional, Tuple |
23 | 24 | from unittest import IsolatedAsyncioTestCase |
24 | 25 | from unittest.mock import AsyncMock, MagicMock, patch |
25 | 26 |
|
26 | 27 | import fastapi |
| 28 | +import httpx |
27 | 29 | import netaddr |
28 | 30 | import pydantic |
29 | 31 | import pytest |
| 32 | +import uvicorn |
30 | 33 | from fastapi.testclient import TestClient |
31 | 34 | from starlette.requests import Request |
32 | 35 |
|
33 | | -from bittensor.core.axon import AxonMiddleware, Axon |
| 36 | +from bittensor.core.axon import Axon, AxonMiddleware, FastAPIThreadedServer |
34 | 37 | from bittensor.core.errors import RunException |
35 | 38 | from bittensor.core.settings import version_as_int |
36 | 39 | from bittensor.core.stream import StreamingSynapse |
@@ -785,3 +788,44 @@ async def forward_fn(synapse: streaming_synapse_cls): |
785 | 788 | "computed_body_hash": "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a", |
786 | 789 | }, |
787 | 790 | ) |
| 791 | + |
| 792 | + |
| 793 | +def test_threaded_fastapi(): |
| 794 | + server_started = threading.Event() |
| 795 | + server_stopped = threading.Event() |
| 796 | + |
| 797 | + async def lifespan(app): |
| 798 | + server_started.set() |
| 799 | + yield |
| 800 | + server_stopped.set() |
| 801 | + |
| 802 | + app = fastapi.FastAPI( |
| 803 | + lifespan=lifespan, |
| 804 | + ) |
| 805 | + app.get("/")(lambda: "Hello World") |
| 806 | + |
| 807 | + server = FastAPIThreadedServer( |
| 808 | + uvicorn.Config( |
| 809 | + app, |
| 810 | + ), |
| 811 | + ) |
| 812 | + server.start() |
| 813 | + |
| 814 | + server_started.wait() |
| 815 | + |
| 816 | + assert server.is_running is True |
| 817 | + |
| 818 | + client = httpx.Client( |
| 819 | + base_url="http://127.0.0.1:8000", |
| 820 | + ) |
| 821 | + |
| 822 | + assert client.get("/").text == '"Hello World"' |
| 823 | + |
| 824 | + server.stop() |
| 825 | + |
| 826 | + assert server.should_exit is True |
| 827 | + |
| 828 | + server_stopped.wait() |
| 829 | + |
| 830 | + with pytest.raises(httpx.ConnectError): |
| 831 | + client.get("/") |
0 commit comments