@@ -8,7 +8,7 @@
 import uuid
 from abc import ABC
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from build.utils import device_sync
 
@@ -87,31 +87,39 @@ class StreamOptions:
     include_usage: bool = False
 
 
+@dataclass
+class ResponseFormat:
+    type: Optional[str] = None
+
+
 @dataclass
 class CompletionRequest:
     """A full chat completion request.
 
     See the "Create Chat Completion >>> Request body" section of the OpenAI API docs for more details.
     """
 
+    messages: List[_AbstractMessage]
     model: str
-    prompt: str
-    messages: Optional[List[_AbstractMessage]]
-    frequency_penalty: float = 0.0
-    temperature: float = 0.0
-    stop: Optional[List[str]] = None
-    stream: bool = False
-    stream_options: Optional[StreamOptions] = None
-    echo: bool = False
-    frequency_penalty: float = 0.0
-    guided_decode_json_schema: str = None
-    guided_decode_json_schema_path: str = None
+    frequency_penalty: float = 0.0  # unimplemented
+    logit_bias: Optional[Dict[str, float]] = None  # unimplemented
+    logprobs: Optional[bool] = None  # unimplemented
+    top_logprobs: Optional[int] = None  # unimplemented
+    max_tokens: Optional[int] = None  # unimplemented
     n: int = 1
-    presence_penalty: float = 0
-    logit_bias: Optional[Dict[str, float]] = None
-    logprobs: Optional[bool] = None
-    top_logprobs: Optional[int] = None
-    max_tokens: Optional[int] = None
+    presence_penalty: float = 0  # unimplemented
+    response_format: Optional[ResponseFormat] = None  # unimplemented
+    seed: Optional[int] = None  # unimplemented
+    service_tier: Optional[str] = None  # unimplemented
+    stop: Optional[List[str]] = None  # unimplemented
+    stream: bool = False
+    stream_options: Optional[StreamOptions] = None  # unimplemented
+    temperature: Optional[float] = 1.0  # unimplemented
+    top_p: Optional[float] = 1.0  # unimplemented
+    tools: Optional[List[Any]] = None  # unimplemented
+    tool_choice: Optional[Union[str, Any]] = None  # unimplemented
+    parallel_tool_calls: Optional[bool] = None  # unimplemented
+    user: Optional[str] = None  # unimplemented
 
 
 @dataclass
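For orientation, here is a sketch of how the reworked request can be built. The declared element type is _AbstractMessage, but later hunks read entries with .get("content"), so plain dicts are assumed below, and the model name is a placeholder:

    # Only `messages` and `model` are required; every other field has a default.
    req = CompletionRequest(
        model="stories15M",  # placeholder model name, for illustration only
        messages=[{"role": "user", "content": "Hello!"}],
        stream=True,
    )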
@@ -121,10 +129,10 @@ class CompletionChoice:
     See the "The chat completion object >>> choices" section of the OpenAI API docs for more details.
     """
 
-    finish_reason: str
     index: int
     message: AssistantMessage
-    logprobs: Optional[List[Any]]
+    finish_reason: Optional[str] = None
+    logprobs: Optional[List[Any]] = None
 
 
 @dataclass
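Moving finish_reason and logprobs below message follows from Python's dataclass rules: a field without a default may not come after one with a default, so fields that gain defaults must sit behind the required ones. A minimal reproduction of the error the old ordering would raise once finish_reason had its new default:

    from dataclasses import dataclass

    @dataclass
    class Broken:
        finish_reason: str = None
        index: int  # TypeError: non-default argument 'index' follows default argument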
@@ -151,9 +159,9 @@ class CompletionResponse:
     created: int
     model: str
     system_fingerprint: str
-    usage: UsageStats
-    object: str = "chat.completion"
     service_tier: Optional[str] = None
+    usage: Optional[UsageStats] = None
+    object: str = "chat.completion"
 
 
 @dataclass
@@ -193,8 +201,8 @@ class CompletionResponseChunk:
     created: int
     model: str
     system_fingerprint: str
-    object: str = "chat.completion.chunk"
     service_tier: Optional[str] = None
+    object: str = "chat.completion.chunk"
     usage: Optional[UsageStats] = None
 
 
@@ -220,8 +228,13 @@ def __init__(self, *args, **kwargs):
             if self.draft_model is not None
             else self.model.config.max_seq_length
         )
+        # The System fingerprint is a unique identifier for the model and its configuration.
+        # Currently, this is not implemented in a standardized way and is subject to change.
+        self.system_fingerprint = (
+            self.builder_args.device + type(self.builder_args.precision).__name__
+        )
 
-    def completion(self, completion_request: CompletionRequest):
+    def chunked_completion(self, completion_request: CompletionRequest):
         """Handle a chat completion request and yield a chunked response.
 
         **Warning**: Not all arguments of the CompletionRequest are consumed, as the server isn't fully implemented.
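Note what the fingerprint concatenation added above actually yields: builder_args.precision is presumably a torch.dtype instance, and type(torch.float16).__name__ is simply "dtype", so the fingerprint comes out as, say, "cpudtype", which is consistent with the "subject to change" caveat. A quick sanity check with the two inputs stubbed out:

    import torch

    device = "cpu"             # stand-in for self.builder_args.device
    precision = torch.float16  # stand-in for self.builder_args.precision

    # type(precision) is torch.dtype, whose __name__ is "dtype"
    print(device + type(precision).__name__)  # prints "cpudtype"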
@@ -230,7 +243,8 @@ def completion(self, completion_request: CompletionRequest):
         - messages: The server consumes the final element of the array as the prompt.
         - model: This has no impact on the server state, i.e. changing the model in the request
           will not change which model is responding. Instead, use the --model flag to select the model when starting the server.
-        - temperature: This is used to control the randomness of the response. The server will use the temperature
+        - temperature: This is used to control the randomness of the response.
+        - system_fingerprint: A unique identifier for the model and its configuration. Currently unimplemented - subject to change.
 
         See https://github.com/pytorch/torchchat/issues/973 for more details.
 
@@ -246,13 +260,16 @@ def completion(self, completion_request: CompletionRequest):
 
         # Initialize counters for chunk responses and encode the prompt.
         id = str(uuid.uuid4())
+
         idx = 0
         buffer = []
         encoded = self.encode_tokens(
-            completion_request.prompt, bos=True, device=self.builder_args.device
+            completion_request.messages[-1].get("content"),
+            bos=True,
+            device=self.builder_args.device,
         )
         generator_args = GeneratorArgs(
-            completion_request.prompt,
+            completion_request.messages[-1].get("content"),
             encoded_prompt=encoded,
             chat_mode=False,
         )
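As the docstring above notes, only the final element of messages becomes the prompt. A small sketch of that extraction, again assuming dict-like messages:

    messages = [
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Hi!"},
    ]
    prompt = messages[-1].get("content")  # "Hi!"; earlier messages are ignored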
@@ -302,21 +319,45 @@ def callback(x, *, done_generating=False):
                 choices=[choice_chunk],
                 created=int(time.time()),
                 model=completion_request.model,
-                system_fingerprint=uuid.UUID(int=uuid.getnode()),
+                system_fingerprint=self.system_fingerprint,
             )
             yield chunk_response
             self.start_pos += y.size(0)
             idx += 1
 
         # Yield an ending chunk indicating the generation has completed.
-        end_chunk = CompletionChoiceChunk(ChunkDelta(None, None, None), idx, "eos")
+        end_chunk = CompletionChoiceChunk(
+            ChunkDelta(None, None, None), idx, finish_reason="stop"
+        )
 
         yield CompletionResponseChunk(
             id=str(id),
             choices=[end_chunk],
             created=int(time.time()),
             model=completion_request.model,
-            system_fingerprint=uuid.UUID(int=uuid.getnode()),
+            system_fingerprint=self.system_fingerprint,
+        )
+
+    def sync_completion(self, request: CompletionRequest):
+        """Handle a chat completion request and return a single, non-chunked response."""
+        output = ""
+        for chunk in self.chunked_completion(request):
+            if not chunk.choices[0].finish_reason:
+                output += chunk.choices[0].delta.content
+
+        message = AssistantMessage(content=output)
+        return CompletionResponse(
+            id=str(uuid.uuid4()),
+            choices=[
+                CompletionChoice(
+                    finish_reason="stop",
+                    index=0,
+                    message=message,
+                )
+            ],
+            created=int(time.time()),
+            model=request.model,
+            system_fingerprint=self.system_fingerprint,
         )
 
     def _callback(self, x, *, buffer, done_generating):
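Taken together, the two entry points serve streaming and non-streaming clients. A hypothetical usage sketch; the enclosing generator class and its construction are not shown in this diff, so gen stands in for an already-built instance:

    req = CompletionRequest(
        model="stories15M",  # placeholder model name
        messages=[{"role": "user", "content": "Tell me a joke."}],
    )

    # Streaming: print deltas until the final finish_reason chunk arrives.
    for chunk in gen.chunked_completion(req):
        if not chunk.choices[0].finish_reason:
            print(chunk.choices[0].delta.content, end="", flush=True)

    # Non-streaming: a single CompletionResponse carrying the whole message.
    resp = gen.sync_completion(req)
    print(resp.choices[0].message.content)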