 import uuid
 from abc import ABC
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from build.utils import device_sync

@@ -86,6 +86,9 @@ class StreamOptions:

     include_usage: bool = False

+@dataclass
+class ResponseFormat:
+    type: Optional[str] = None

 @dataclass
 class CompletionRequest:
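The new `ResponseFormat` mirrors the `response_format` object in OpenAI's request body. A hypothetical caller-side sketch (the field is still flagged unimplemented below, so the server does not act on it yet):

```python
# Hypothetical usage sketch; response_format is still unimplemented server-side.
fmt = ResponseFormat(type="json_object")  # mirrors OpenAI's {"type": "json_object"}
```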
@@ -94,25 +97,27 @@ class CompletionRequest:
     See the "Create Chat Completion >>> Request body" section of the OpenAI API docs for more details.
     """

+    messages: List[_AbstractMessage]
     model: str
-    prompt: str
-    messages: Optional[List[_AbstractMessage]]
-    frequency_penalty: float = 0.0
-    temperature: float = 0.0
-    stop: Optional[List[str]] = None
-    stream: bool = False
-    stream_options: Optional[StreamOptions] = None
-    echo: bool = False
-    frequency_penalty: float = 0.0
-    guided_decode_json_schema: str = None
-    guided_decode_json_schema_path: str = None
+    frequency_penalty: float = 0.0  # unimplemented
+    logit_bias: Optional[Dict[str, float]] = None  # unimplemented
+    logprobs: Optional[bool] = None  # unimplemented
+    top_logprobs: Optional[int] = None  # unimplemented
+    max_tokens: Optional[int] = None  # unimplemented
     n: int = 1
-    presence_penalty: float = 0
-    logit_bias: Optional[Dict[str, float]] = None
-    logprobs: Optional[bool] = None
-    top_logprobs: Optional[int] = None
-    max_tokens: Optional[int] = None
-
+    presence_penalty: float = 0  # unimplemented
+    response_format: Optional[ResponseFormat] = None  # unimplemented
+    seed: Optional[int] = None  # unimplemented
+    service_tier: Optional[str] = None  # unimplemented
+    stop: Optional[List[str]] = None  # unimplemented
+    stream: bool = False
+    stream_options: Optional[StreamOptions] = None  # unimplemented
+    temperature: Optional[float] = 1.0  # unimplemented
+    top_p: Optional[float] = 1.0  # unimplemented
+    tools: Optional[List[Any]] = None  # unimplemented
+    tool_choice: Optional[Union[str, Any]] = None  # unimplemented
+    parallel_tool_calls: Optional[bool] = None  # unimplemented
+    user: Optional[str] = None  # unimplemented

 @dataclass
 class CompletionChoice:
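To make the new schema concrete, here is a minimal request sketch. The dict-style message mirrors how the server reads `messages[-1].get("content")` later in this diff; the model name is a placeholder, since the server ignores the field:

```python
# Minimal sketch of the new request shape: `messages` is now required and
# replaces the removed `prompt` field. The model name is a placeholder; the
# server selects its model via the --model startup flag instead.
request = CompletionRequest(
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    model="placeholder-model",
    stream=True,
)
```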
@@ -121,10 +126,10 @@ class CompletionChoice:
     See the "The chat completion object >>> choices" section of the OpenAI API docs for more details.
     """

-    finish_reason: str
     index: int
     message: AssistantMessage
-    logprobs: Optional[List[Any]]
+    finish_reason: Optional[str] = None
+    logprobs: Optional[List[Any]] = None


 @dataclass
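Aside on the reordering above: dataclass fields without defaults (`index`, `message`) must precede fields with defaults, which is why `finish_reason` and `logprobs` gain `None` defaults and move to the end. A minimal illustration of the rule:

```python
from dataclasses import dataclass

@dataclass
class Broken:
    finish_reason: str = None  # has a default...
    index: int                 # TypeError: non-default argument 'index'
                               # follows default argument
```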
@@ -150,10 +155,10 @@ class CompletionResponse:
     choices: List[CompletionChoice]
     created: int
     model: str
-    system_fingerprint: str
-    usage: UsageStats
-    object: str = "chat.completion"
+    system_fingerprint: str
     service_tier: Optional[str] = None
+    usage: Optional[UsageStats] = None
+    object: str = "chat.completion"


 @dataclass
@@ -193,8 +198,8 @@ class CompletionResponseChunk:
     created: int
     model: str
     system_fingerprint: str
-    object: str = "chat.completion.chunk"
     service_tier: Optional[str] = None
+    object: str = "chat.completion.chunk"
     usage: Optional[UsageStats] = None

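For orientation, one streamed chunk serialized in the OpenAI wire format would look roughly like this. All values are illustrative; the fingerprint scheme follows the device-plus-precision-type construction added in `__init__` below:

```python
# Illustrative wire shape of a single streamed chunk, per the OpenAI API docs.
chunk = {
    "id": "chunk-id-placeholder",        # str(uuid.uuid4())
    "object": "chat.completion.chunk",
    "created": 1700000000,               # int(time.time())
    "model": "placeholder-model",        # echoed from the request
    "system_fingerprint": "<device><precision>",
    "choices": [
        {
            "index": 0,
            "delta": {"role": "assistant", "content": "Hel"},
            "finish_reason": None,       # "stop" only on the final chunk
        }
    ],
}
```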
@@ -220,8 +225,13 @@ def __init__(self, *args, **kwargs):
             if self.draft_model is not None
             else self.model.config.max_seq_length
         )
+        # The system fingerprint is a unique identifier for the model and its configuration.
+        # Currently it is a placeholder derived from the device and the precision type.
+        self.system_fingerprint = (
+            self.builder_args.device + type(self.builder_args.precision).__name__
+        )

-    def completion(self, completion_request: CompletionRequest):
+    def chunked_completion(self, completion_request: CompletionRequest):
         """Handle a chat completion request and yield a chunked response.

         **Warning**: Not all arguments of the CompletionRequest are consumed as the server isn't completely implemented.
@@ -230,7 +240,8 @@ def completion(self, completion_request: CompletionRequest):
         - messages: The server consumes the final element of the array as the prompt.
         - model: This has no impact on the server state, i.e. changing the model in the request
         will not change which model is responding. Instead, use the --model flag to select the model when starting the server.
-        - temperature: This is used to control the randomness of the response. The server will use the temperature
+        - temperature: This is used to control the randomness of the response.
+        - system_fingerprint: A unique identifier for the model and its configuration. Currently unimplemented - subject to change.

         See https://github.com/pytorch/torchchat/issues/973 for more details.

@@ -246,13 +257,16 @@ def completion(self, completion_request: CompletionRequest):

         # Initialize counters for chunk responses and encode the prompt.
         id = str(uuid.uuid4())
+
         idx = 0
         buffer = []
         encoded = self.encode_tokens(
-            completion_request.prompt, bos=True, device=self.builder_args.device
+            completion_request.messages[-1].get("content"),
+            bos=True,
+            device=self.builder_args.device,
         )
         generator_args = GeneratorArgs(
-            completion_request.prompt,
+            completion_request.messages[-1].get("content"),
             encoded_prompt=encoded,
             chat_mode=False,
         )
@@ -302,21 +316,45 @@ def callback(x, *, done_generating=False):
                 choices=[choice_chunk],
                 created=int(time.time()),
                 model=completion_request.model,
-                system_fingerprint=uuid.UUID(int=uuid.getnode()),
+                system_fingerprint=self.system_fingerprint,
             )
             yield chunk_response
             self.start_pos += y.size(0)
             idx += 1

         # Yield an ending chunk indicating the generation has completed.
-        end_chunk = CompletionChoiceChunk(ChunkDelta(None, None, None), idx, "eos")
+        end_chunk = CompletionChoiceChunk(
+            ChunkDelta(None, None, None), idx, finish_reason="stop"
+        )

         yield CompletionResponseChunk(
             id=str(id),
             choices=[end_chunk],
             created=int(time.time()),
             model=completion_request.model,
-            system_fingerprint=uuid.UUID(int=uuid.getnode()),
+            system_fingerprint=self.system_fingerprint,
+        )
+
+    def sync_completion(self, request: CompletionRequest):
+        """Handle a chat completion request and yield a single, non-chunked response."""
+        output = ""
+        for chunk in self.chunked_completion(request):
+            if not chunk.choices[0].finish_reason:
+                output += chunk.choices[0].delta.content
+
+        message = AssistantMessage(content=output)
+        return CompletionResponse(
+            id=str(uuid.uuid4()),
+            choices=[
+                CompletionChoice(
+                    finish_reason="stop",
+                    index=0,
+                    message=message,
+                )
+            ],
+            created=int(time.time()),
+            model=request.model,
+            system_fingerprint=self.system_fingerprint,
         )

     def _callback(self, x, *, buffer, done_generating):
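Finally, a hedged usage sketch tying the two entry points together. Here `server` stands in for the object that defines `chunked_completion` and `sync_completion`; its construction is omitted, and the request values are placeholders:

```python
# Assumed: `server` is an instance of the class patched in this diff.
req = CompletionRequest(
    messages=[{"role": "user", "content": "Tell me a joke."}],
    model="placeholder-model",  # selection happens via the --model startup flag
    stream=True,
)

# Streaming: iterate the chunked response and stitch the deltas together.
for chunk in server.chunked_completion(req):
    if not chunk.choices[0].finish_reason:
        print(chunk.choices[0].delta.content, end="", flush=True)

# Non-streaming: one CompletionResponse carrying the full assistant message.
resp = server.sync_completion(req)
print(resp.choices[0].message.content)
```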