@@ -95,8 +95,7 @@ class CompletionRequest:
     """

     model: str
-    prompt: str
-    messages: Optional[List[_AbstractMessage]]
+    messages: List[_AbstractMessage]
     frequency_penalty: float = 0.0
     temperature: float = 0.0
     stop: Optional[List[str]] = None
@@ -121,10 +120,10 @@ class CompletionChoice:
     See the "The chat completion object >>> choices" section of the OpenAI API docs for more details.
     """

-    finish_reason: str
     index: int
     message: AssistantMessage
-    logprobs: Optional[List[Any]]
+    finish_reason: str = None
+    logprobs: Optional[List[Any]] = None


 @dataclass
@@ -151,7 +150,7 @@ class CompletionResponse:
     created: int
     model: str
     system_fingerprint: str
-    usage: UsageStats
+    usage: Optional[UsageStats] = None
     object: str = "chat.completion"
     service_tier: Optional[str] = None

@@ -220,8 +219,11 @@ def __init__(self, *args, **kwargs):
             if self.draft_model is not None
             else self.model.config.max_seq_length
         )
+        self.system_fingerprint = (
+            self.builder_args.device + type(self.builder_args.precision).__name__
+        )

-    def completion(self, completion_request: CompletionRequest):
+    def chunked_completion(self, completion_request: CompletionRequest):
         """Handle a chat completion request and yield a chunked response.

         **Warning**: Not all arguments of the CompletionRequest are consumed as the server isn't completely implemented.
@@ -246,13 +248,16 @@ def completion(self, completion_request: CompletionRequest):

         # Initialize counters for chunk responses and encode the prompt.
         id = str(uuid.uuid4())
+
         idx = 0
         buffer = []
         encoded = self.encode_tokens(
-            completion_request.prompt, bos=True, device=self.builder_args.device
+            completion_request.messages[-1].get("content"),
+            bos=True,
+            device=self.builder_args.device,
         )
         generator_args = GeneratorArgs(
-            completion_request.prompt,
+            completion_request.messages[-1].get("content"),
             encoded_prompt=encoded,
             chat_mode=False,
         )
@@ -302,21 +307,45 @@ def callback(x, *, done_generating=False):
                 choices=[choice_chunk],
                 created=int(time.time()),
                 model=completion_request.model,
-                system_fingerprint=uuid.UUID(int=uuid.getnode()),
+                system_fingerprint=self.system_fingerprint,
             )
             yield chunk_response
             self.start_pos += y.size(0)
             idx += 1

         # Yield an ending chunk indicating the generation has completed.
-        end_chunk = CompletionChoiceChunk(ChunkDelta(None, None, None), idx, "eos")
+        end_chunk = CompletionChoiceChunk(
+            ChunkDelta(None, None, None), idx, finish_reason="stop"
+        )

         yield CompletionResponseChunk(
             id=str(id),
             choices=[end_chunk],
             created=int(time.time()),
             model=completion_request.model,
-            system_fingerprint=uuid.UUID(int=uuid.getnode()),
+            system_fingerprint=self.system_fingerprint,
+        )
+
+    def sync_completion(self, request: CompletionRequest):
+        """Handle a chat completion request and yield a single, non-chunked response"""
+        output = ""
+        for chunk in self.chunked_completion(request):
+            if not chunk.choices[0].finish_reason:
+                output += chunk.choices[0].delta.content
+
+        message = AssistantMessage(content=output)
+        return CompletionResponse(
+            id=str(uuid.uuid4()),
+            choices=[
+                CompletionChoice(
+                    finish_reason="stop",
+                    index=0,
+                    message=message,
+                )
+            ],
+            created=int(time.time()),
+            model=request.model,
+            system_fingerprint=self.system_fingerprint,
+        )

     def _callback(self, x, *, buffer, done_generating):
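
Usage sketch (not part of the diff above): how a caller might exercise the new chunked_completion / sync_completion split. This is a minimal sketch under assumptions: `gen` stands in for an instance of the generator class this diff modifies, the "llama3" model id is a placeholder, the dict-shaped user message is inferred from the messages[-1].get("content") access above, and any CompletionRequest fields outside the shown hunk are assumed to have defaults.

# Hypothetical caller-side sketch; `gen` is assumed to be an instance of the
# API generator class modified in this diff.
request = CompletionRequest(
    model="llama3",  # placeholder model identifier
    messages=[{"role": "user", "content": "Tell me a joke."}],  # last message supplies the prompt
)

# Streaming path: consume deltas until the ending chunk that carries finish_reason="stop".
for chunk in gen.chunked_completion(request):
    if not chunk.choices[0].finish_reason:
        print(chunk.choices[0].delta.content, end="", flush=True)

# Non-streaming path: sync_completion drains the same chunk stream internally and
# returns a single CompletionResponse wrapping one AssistantMessage.
response = gen.sync_completion(request)
print(response.choices[0].message.content)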