Commit 34f9431

Authored by freddyaboulton, with gradio-pr-bot and abidlabs
Python client properly handles heartbeat and log messages. Also handles responses longer than 65k (#6693)
* first commit
* newlines
* test
* Fix depends
* revert
* add changeset
* add changeset
* Lint
* queue full test
* Add code
* Update + fix
* add changeset
* Revert demo
* Typo in success
* Fix

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
1 parent a3cf90e commit 34f9431
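
As context for the fix, a minimal, hypothetical repro sketch (the URL and endpoint name below are assumptions for illustration, not taken from this PR): an app whose prediction emits gr.Info() log messages, runs long enough for the server to send heartbeat events, and returns a payload larger than 65k. With this commit, predict() returns the full result instead of failing on heartbeat/log events or on an SSE message split across chunks.

# Hypothetical usage sketch; the URL and api_name are assumptions.
from gradio_client import Client

client = Client("http://127.0.0.1:7860")  # assumed local demo URL
result = client.predict(api_name="/predict")
print(len(result))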

File tree

9 files changed (+197, -70 lines)

.changeset/yummy-roses-decide.md

+7
@@ -0,0 +1,7 @@
+---
+"@gradio/client": patch
+"gradio": patch
+"gradio_client": patch
+---
+
+fix:Python client properly handles heartbeat and log messages. Also handles responses longer than 65k

client/js/src/client.ts

+21-4
@@ -203,8 +203,16 @@ export function api_factory(
 		} catch (e) {
 			return [{ error: BROKEN_CONNECTION_MSG }, 500];
 		}
-		const output: PostResponse = await response.json();
-		return [output, response.status];
+		let output: PostResponse;
+		let status: int;
+		try {
+			output = await response.json();
+			status = response.status;
+		} catch (e) {
+			output = { error: `Could not parse server response: ${e}` };
+			status = 500;
+		}
+		return [output, status];
 	}
 
 	async function upload_files(
@@ -791,7 +799,17 @@ export function api_factory(
 					},
 					hf_token
 				).then(([response, status]) => {
-					if (status !== 200) {
+					if (status === 503) {
+						fire_event({
+							type: "status",
+							stage: "error",
+							message: QUEUE_FULL_MSG,
+							queue: true,
+							endpoint: _endpoint,
+							fn_index,
+							time: new Date()
+						});
+					} else if (status !== 200) {
 						fire_event({
 							type: "status",
 							stage: "error",
@@ -806,7 +824,6 @@ export function api_factory(
 				if (!stream_open) {
 					open_stream();
 				}
-
 				let callback = async function (_data: object): void {
 					const { type, status, data } = handle_message(
 						_data,
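
The first hunk above wraps `response.json()` so that an unparseable body degrades into an `{ error: ... }` payload with status 500 rather than an unhandled rejection. A rough sketch of the same defensive pattern in Python (illustrative only; this `post_data` is a stand-in, not the client's actual function):

# Sketch of the JSON-parse guard with httpx; names are illustrative.
import httpx

def post_data(url: str, payload: dict) -> tuple[dict, int]:
    try:
        response = httpx.post(url, json=payload)
    except httpx.TransportError:
        return {"error": "Connection errored out."}, 500
    try:
        # .json() raises a ValueError subclass if the body is not valid JSON
        return response.json(), response.status_code
    except ValueError as e:
        return {"error": f"Could not parse server response: {e}"}, 500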

client/python/gradio_client/client.py

+28-27
@@ -37,6 +37,8 @@
     Communicator,
     JobStatus,
     Message,
+    QueueError,
+    ServerMessage,
     Status,
     StatusUpdate,
 )
@@ -169,41 +171,38 @@ def __init__(
     async def stream_messages(self) -> None:
         try:
             async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client:
-                buffer = ""
                 async with client.stream(
                     "GET",
                     self.sse_url,
                     params={"session_hash": self.session_hash},
                     headers=self.headers,
                     cookies=self.cookies,
                 ) as response:
-                    async for line in response.aiter_text():
-                        buffer += line
-                        while "\n\n" in buffer:
-                            message, buffer = buffer.split("\n\n", 1)
-                            if message.startswith("data:"):
-                                resp = json.loads(message[5:])
-                                if resp["msg"] == "heartbeat":
-                                    continue
-                                elif resp["msg"] == "server_stopped":
-                                    for (
-                                        pending_messages
-                                    ) in self.pending_messages_per_event.values():
-                                        pending_messages.append(resp)
-                                    return
-                                event_id = resp["event_id"]
-                                if event_id not in self.pending_messages_per_event:
-                                    self.pending_messages_per_event[event_id] = []
-                                self.pending_messages_per_event[event_id].append(resp)
-                                if resp["msg"] == "process_completed":
-                                    self.pending_event_ids.remove(event_id)
-                                    if len(self.pending_event_ids) == 0:
-                                        self.stream_open = False
-                                        return
-                            elif message == "":
-                                continue
-                            else:
-                                raise ValueError(f"Unexpected SSE line: '{message}'")
+                    async for line in response.aiter_lines():
+                        line = line.rstrip("\n")
+                        if not len(line):
+                            continue
+                        if line.startswith("data:"):
+                            resp = json.loads(line[5:])
+                            if resp["msg"] == ServerMessage.heartbeat:
+                                continue
+                            elif resp["msg"] == ServerMessage.server_stopped:
+                                for (
+                                    pending_messages
+                                ) in self.pending_messages_per_event.values():
+                                    pending_messages.append(resp)
+                                return
+                            event_id = resp["event_id"]
+                            if event_id not in self.pending_messages_per_event:
+                                self.pending_messages_per_event[event_id] = []
+                            self.pending_messages_per_event[event_id].append(resp)
+                            if resp["msg"] == ServerMessage.process_completed:
+                                self.pending_event_ids.remove(event_id)
+                                if len(self.pending_event_ids) == 0:
+                                    self.stream_open = False
+                                    return
+                        else:
+                            raise ValueError(f"Unexpected SSE line: '{line}'")
         except BaseException as e:
             import traceback
 
@@ -218,6 +217,8 @@ async def send_data(self, data, hash_data):
             headers=self.headers,
             cookies=self.cookies,
         )
+        if req.status_code == 503:
+            raise QueueError("Queue is full! Please try again.")
         req.raise_for_status()
         resp = req.json()
         event_id = resp["event_id"]
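
The core change here is replacing the hand-rolled `aiter_text()` buffer with `aiter_lines()`: `aiter_text()` yields transport-sized chunks, so a single `data:` message longer than one chunk (e.g. a response over 65k) could arrive split across iterations, whereas `aiter_lines()` buffers internally and always yields complete lines. A stripped-down sketch of the new loop, assuming the server frames each event as one `data: {...}` line:

# Minimal SSE reader sketch (assumption: one JSON object per "data:" line).
import json
import httpx

async def read_sse(url: str):
    async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client:
        async with client.stream("GET", url) as response:
            async for line in response.aiter_lines():
                if line.startswith("data:"):
                    # a complete line, regardless of payload size
                    yield json.loads(line[len("data:"):])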

client/python/gradio_client/utils.py

+34-17
@@ -102,6 +102,20 @@ class SpaceDuplicationError(Exception):
     pass
 
 
+class ServerMessage(str, Enum):
+    send_hash = "send_hash"
+    queue_full = "queue_full"
+    estimation = "estimation"
+    send_data = "send_data"
+    process_starts = "process_starts"
+    process_generating = "process_generating"
+    process_completed = "process_completed"
+    log = "log"
+    progress = "progress"
+    heartbeat = "heartbeat"
+    server_stopped = "server_stopped"
+
+
 class Status(Enum):
     """Status codes presented to client users."""
 
@@ -141,16 +155,17 @@ def __lt__(self, other: Status):
 def msg_to_status(msg: str) -> Status:
     """Map the raw message from the backend to the status code presented to users."""
     return {
-        "send_hash": Status.JOINING_QUEUE,
-        "queue_full": Status.QUEUE_FULL,
-        "estimation": Status.IN_QUEUE,
-        "send_data": Status.SENDING_DATA,
-        "process_starts": Status.PROCESSING,
-        "process_generating": Status.ITERATING,
-        "process_completed": Status.FINISHED,
-        "progress": Status.PROGRESS,
-        "log": Status.LOG,
-    }[msg]
+        ServerMessage.send_hash: Status.JOINING_QUEUE,
+        ServerMessage.queue_full: Status.QUEUE_FULL,
+        ServerMessage.estimation: Status.IN_QUEUE,
+        ServerMessage.send_data: Status.SENDING_DATA,
+        ServerMessage.process_starts: Status.PROCESSING,
+        ServerMessage.process_generating: Status.ITERATING,
+        ServerMessage.process_completed: Status.FINISHED,
+        ServerMessage.progress: Status.PROGRESS,
+        ServerMessage.log: Status.LOG,
+        ServerMessage.server_stopped: Status.FINISHED,
+    }[msg]  # type: ignore
 
 
 @dataclass
@@ -436,9 +451,14 @@ async def stream_sse_v0(
             headers=headers,
             cookies=cookies,
         ) as response:
-            async for line in response.aiter_text():
+            async for line in response.aiter_lines():
+                line = line.rstrip("\n")
+                if len(line) == 0:
+                    continue
                 if line.startswith("data:"):
                     resp = json.loads(line[5:])
+                    if resp["msg"] in [ServerMessage.log, ServerMessage.heartbeat]:
+                        continue
                     with helper.lock:
                         has_progress = "progress_data" in resp
                         status_update = StatusUpdate(
@@ -502,7 +522,7 @@ async def stream_sse_v1(
 
             with helper.lock:
                 log_message = None
-                if msg["msg"] == "log":
+                if msg["msg"] == ServerMessage.log:
                     log = msg.get("log")
                     level = msg.get("level")
                     if log and level:
@@ -527,13 +547,10 @@ async def stream_sse_v1(
                             result = [e]
                         helper.job.outputs.append(result)
                     helper.job.latest_status = status_update
-
-                if msg["msg"] == "queue_full":
-                    raise QueueError("Queue is full! Please try again.")
-                elif msg["msg"] == "process_completed":
+                if msg["msg"] == ServerMessage.process_completed:
                     del pending_messages_per_event[event_id]
                     return msg["output"]
-                elif msg["msg"] == "server_stopped":
+                elif msg["msg"] == ServerMessage.server_stopped:
                     raise ValueError("Server stopped.")
 
         except asyncio.CancelledError:
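
Because `ServerMessage` mixes in `str`, its members compare equal to the raw strings decoded from JSON, which is why checks like `resp["msg"] == ServerMessage.heartbeat` work without any conversion. A quick self-contained illustration:

# str/Enum mixin members compare equal to their plain-string values.
from enum import Enum

class ServerMessage(str, Enum):
    log = "log"
    heartbeat = "heartbeat"

resp = {"msg": "heartbeat"}  # as decoded from a JSON payload
print(resp["msg"] == ServerMessage.heartbeat)                       # True
print(resp["msg"] in [ServerMessage.log, ServerMessage.heartbeat])  # True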

client/python/test/conftest.py

+15
@@ -381,3 +381,18 @@ def gradio_temp_dir(monkeypatch, tmp_path):
     """
     monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
     return tmp_path
+
+
+@pytest.fixture
+def long_response_with_info():
+    def long_response(x):
+        gr.Info("Beginning long response")
+        time.sleep(17)
+        gr.Info("Done!")
+        return "\ta\nb" * 90000
+
+    return gr.Interface(
+        long_response,
+        None,
+        gr.Textbox(label="Output"),
+    )
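
The fixture exercises both halves of the fix: presumably the `time.sleep(17)` outlasts the server's heartbeat interval so the client is guaranteed to see heartbeat (and `gr.Info` log) messages mid-job, and the 90000-fold repetition produces a payload well past the 65k threshold named in the commit title. Both readings are inferences; the diff itself does not say why these values were chosen.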

client/python/test/test_client.py

+42-2
@@ -4,7 +4,7 @@
 import tempfile
 import time
 import uuid
-from concurrent.futures import CancelledError, TimeoutError
+from concurrent.futures import CancelledError, TimeoutError, wait
 from contextlib import contextmanager
 from datetime import datetime, timedelta
 from pathlib import Path
@@ -21,7 +21,13 @@
 
 from gradio_client import Client
 from gradio_client.client import DEFAULT_TEMP_DIR
-from gradio_client.utils import Communicator, ProgressUnit, Status, StatusUpdate
+from gradio_client.utils import (
+    Communicator,
+    ProgressUnit,
+    QueueError,
+    Status,
+    StatusUpdate,
+)
 
 HF_TOKEN = os.getenv("HF_TOKEN") or HfFolder.get_token()
 
@@ -488,6 +494,40 @@ def test_return_layout_and_state_components(
         assert demo.predict(api_name="/close") == 4
         assert demo.predict("Ali", api_name="/greeting") == ("Hello Ali", 5)
 
+    def test_long_response_time_with_gr_info_and_big_payload(
+        self, long_response_with_info
+    ):
+        with connect(long_response_with_info) as demo:
+            assert demo.predict(api_name="/predict") == "\ta\nb" * 90000
+
+    def test_queue_full_raises_error(self):
+        demo = gr.Interface(lambda s: f"Hello {s}", "textbox", "textbox").queue(
+            max_size=1
+        )
+        with connect(demo) as client:
+            with pytest.raises(QueueError):
+                job1 = client.submit("Freddy", api_name="/predict")
+                job2 = client.submit("Abubakar", api_name="/predict")
+                job3 = client.submit("Pete", api_name="/predict")
+                wait([job1, job2, job3])
+                job1.result()
+                job2.result()
+                job3.result()
+
+    def test_json_parse_error(self):
+        data = (
+            "Bonjour Olivier, tu as l'air bien r\u00e9veill\u00e9 ce matin. Tu veux que je te pr\u00e9pare tes petits-d\u00e9j.\n",
+            None,
+        )
+
+        def return_bad():
+            return data
+
+        demo = gr.Interface(return_bad, None, ["text", "text"])
+        with connect(demo) as client:
+            pred = client.predict(api_name="/predict")
+            assert pred[0] == data[0]
+
 
 class TestStatusUpdates:
     @patch("gradio_client.client.Endpoint.make_end_to_end_fn")
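
For callers, the practical upshot of `test_queue_full_raises_error` is that a full queue now surfaces as a typed `QueueError` rather than a generic HTTP failure; the exception is raised at submit time or when the job's result is awaited. A hedged usage sketch (the URL is an assumption for illustration):

# Usage sketch; the URL and api_name are assumptions.
from gradio_client import Client
from gradio_client.utils import QueueError

client = Client("http://127.0.0.1:7860")  # assumed demo URL
try:
    job = client.submit("Freddy", api_name="/predict")
    print(job.result())
except QueueError:
    # The server answered 503 because its queue max_size was exceeded.
    print("Queue is full! Please try again.")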
