@@ -102,6 +102,20 @@ class SpaceDuplicationError(Exception):
102
102
pass
103
103
104
104
105
+ class ServerMessage (str , Enum ):
106
+ send_hash = "send_hash"
107
+ queue_full = "queue_full"
108
+ estimation = "estimation"
109
+ send_data = "send_data"
110
+ process_starts = "process_starts"
111
+ process_generating = "process_generating"
112
+ process_completed = "process_completed"
113
+ log = "log"
114
+ progress = "progress"
115
+ heartbeat = "heartbeat"
116
+ server_stopped = "server_stopped"
117
+
118
+
105
119
class Status (Enum ):
106
120
"""Status codes presented to client users."""
107
121
@@ -141,16 +155,17 @@ def __lt__(self, other: Status):
141
155
def msg_to_status (msg : str ) -> Status :
142
156
"""Map the raw message from the backend to the status code presented to users."""
143
157
return {
144
- "send_hash" : Status .JOINING_QUEUE ,
145
- "queue_full" : Status .QUEUE_FULL ,
146
- "estimation" : Status .IN_QUEUE ,
147
- "send_data" : Status .SENDING_DATA ,
148
- "process_starts" : Status .PROCESSING ,
149
- "process_generating" : Status .ITERATING ,
150
- "process_completed" : Status .FINISHED ,
151
- "progress" : Status .PROGRESS ,
152
- "log" : Status .LOG ,
153
- }[msg ]
158
+ ServerMessage .send_hash : Status .JOINING_QUEUE ,
159
+ ServerMessage .queue_full : Status .QUEUE_FULL ,
160
+ ServerMessage .estimation : Status .IN_QUEUE ,
161
+ ServerMessage .send_data : Status .SENDING_DATA ,
162
+ ServerMessage .process_starts : Status .PROCESSING ,
163
+ ServerMessage .process_generating : Status .ITERATING ,
164
+ ServerMessage .process_completed : Status .FINISHED ,
165
+ ServerMessage .progress : Status .PROGRESS ,
166
+ ServerMessage .log : Status .LOG ,
167
+ ServerMessage .server_stopped : Status .FINISHED ,
168
+ }[msg ] # type: ignore
154
169
155
170
156
171
@dataclass
@@ -436,9 +451,14 @@ async def stream_sse_v0(
436
451
headers = headers ,
437
452
cookies = cookies ,
438
453
) as response :
439
- async for line in response .aiter_text ():
454
+ async for line in response .aiter_lines ():
455
+ line = line .rstrip ("\n " )
456
+ if len (line ) == 0 :
457
+ continue
440
458
if line .startswith ("data:" ):
441
459
resp = json .loads (line [5 :])
460
+ if resp ["msg" ] in [ServerMessage .log , ServerMessage .heartbeat ]:
461
+ continue
442
462
with helper .lock :
443
463
has_progress = "progress_data" in resp
444
464
status_update = StatusUpdate (
@@ -502,7 +522,7 @@ async def stream_sse_v1(
502
522
503
523
with helper .lock :
504
524
log_message = None
505
- if msg ["msg" ] == " log" :
525
+ if msg ["msg" ] == ServerMessage . log :
506
526
log = msg .get ("log" )
507
527
level = msg .get ("level" )
508
528
if log and level :
@@ -527,13 +547,10 @@ async def stream_sse_v1(
527
547
result = [e ]
528
548
helper .job .outputs .append (result )
529
549
helper .job .latest_status = status_update
530
-
531
- if msg ["msg" ] == "queue_full" :
532
- raise QueueError ("Queue is full! Please try again." )
533
- elif msg ["msg" ] == "process_completed" :
550
+ if msg ["msg" ] == ServerMessage .process_completed :
534
551
del pending_messages_per_event [event_id ]
535
552
return msg ["output" ]
536
- elif msg ["msg" ] == " server_stopped" :
553
+ elif msg ["msg" ] == ServerMessage . server_stopped :
537
554
raise ValueError ("Server stopped." )
538
555
539
556
except asyncio .CancelledError :
0 commit comments