
Commit 975a817

JSON formatted response using OpenAI API types for server completion requests
1 parent 6303c8c commit 975a817

File tree: 2 files changed, +93 -44 lines
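
For orientation, the sketch below is not part of the diff: it shows roughly the OpenAI-style chat completion JSON the `/chat` endpoint returns after this change, once `None`-valued fields are stripped. All values are illustrative placeholders.

```python
# Illustrative only: approximate shape of the JSON returned by /chat after
# this commit. Values are made up; None-valued fields are dropped by the
# server before serialization.
example_response = {
    "id": "d8e8fca2-dc0f-4a9c-b7e8-000000000000",  # uuid4 generated per request
    "object": "chat.completion",
    "created": 1718000000,  # int(time.time())
    "model": "stories15M",  # echoed from the request
    "system_fingerprint": "cuda...",  # device string + precision type name
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hello there!"},
            "finish_reason": "stop",
        }
    ],
}
```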

api/api.py (+44 -12)

```diff
@@ -95,8 +95,7 @@ class CompletionRequest:
     """
 
     model: str
-    prompt: str
-    messages: Optional[List[_AbstractMessage]]
+    messages: List[_AbstractMessage]
     frequency_penalty: float = 0.0
     temperature: float = 0.0
     stop: Optional[List[str]] = None
@@ -121,10 +120,10 @@ class CompletionChoice:
     See the "The chat completion object >>> choices" section of the OpenAI API docs for more details.
     """
 
-    finish_reason: str
     index: int
     message: AssistantMessage
-    logprobs: Optional[List[Any]]
+    finish_reason: str = None
+    logprobs: Optional[List[Any]] = None
 
 
 @dataclass
@@ -151,7 +150,7 @@ class CompletionResponse:
     created: int
     model: str
     system_fingerprint: str
-    usage: UsageStats
+    usage: Optional[UsageStats] = None
     object: str = "chat.completion"
     service_tier: Optional[str] = None
 
@@ -220,8 +219,13 @@ def __init__(self, *args, **kwargs):
             if self.draft_model is not None
             else self.model.config.max_seq_length
         )
+        # The system fingerprint is a unique identifier for the model and its configuration.
+        # Currently, this is not implemented in a
+        self.system_fingerprint = (
+            self.builder_args.device + type(self.builder_args.precision).__name__
+        )
 
-    def completion(self, completion_request: CompletionRequest):
+    def chunked_completion(self, completion_request: CompletionRequest):
         """Handle a chat completion request and yield a chunked response.
 
         ** Warning ** : Not all arguments of the CompletionRequest are consumed as the server isn't completely implemented.
@@ -230,7 +234,8 @@ def completion(self, completion_request: CompletionRequest):
         - messages: The server consumes the final element of the array as the prompt.
         - model: This has no impact on the server state, i.e. changing the model in the request
         will not change which model is responding. Instead, use the --model flag to select the model when starting the server.
-        - temperature: This is used to control the randomness of the response. The server will use the temperature
+        - temperature: This is used to control the randomness of the response.
+        - system_fingerprint: A unique identifier for the model and its configuration. Currently unimplemented - subject to change.
 
         See https://github.com/pytorch/torchchat/issues/973 for more details.
@@ -246,13 +251,16 @@ def completion(self, completion_request: CompletionRequest):
 
         # Initialize counters for chunk responses and encode the prompt.
         id = str(uuid.uuid4())
+
         idx = 0
         buffer = []
         encoded = self.encode_tokens(
-            completion_request.prompt, bos=True, device=self.builder_args.device
+            completion_request.messages[-1].get("content"),
+            bos=True,
+            device=self.builder_args.device,
         )
         generator_args = GeneratorArgs(
-            completion_request.prompt,
+            completion_request.messages[-1].get("content"),
             encoded_prompt=encoded,
             chat_mode=False,
         )
@@ -302,21 +310,45 @@ def callback(x, *, done_generating=False):
                 choices=[choice_chunk],
                 created=int(time.time()),
                 model=completion_request.model,
-                system_fingerprint=uuid.UUID(int=uuid.getnode()),
+                system_fingerprint=self.system_fingerprint,
             )
             yield chunk_response
             self.start_pos += y.size(0)
             idx += 1
 
         # Yield an ending chunk indicating the generation has completed.
-        end_chunk = CompletionChoiceChunk(ChunkDelta(None, None, None), idx, "eos")
+        end_chunk = CompletionChoiceChunk(
+            ChunkDelta(None, None, None), idx, finish_reason="stop"
+        )
 
         yield CompletionResponseChunk(
             id=str(id),
             choices=[end_chunk],
             created=int(time.time()),
             model=completion_request.model,
-            system_fingerprint=uuid.UUID(int=uuid.getnode()),
+            system_fingerprint=self.system_fingerprint,
+        )
+
+    def sync_completion(self, request: CompletionRequest):
+        """Handle a chat completion request and yield a single, non-chunked response"""
+        output = ""
+        for chunk in self.chunked_completion(request):
+            if not chunk.choices[0].finish_reason:
+                output += chunk.choices[0].delta.content
+
+        message = AssistantMessage(content=output)
+        return CompletionResponse(
+            id=str(uuid.uuid4()),
+            choices=[
+                CompletionChoice(
+                    finish_reason="stop",
+                    index=0,
+                    message=message,
+                )
+            ],
+            created=int(time.time()),
+            model=request.model,
+            system_fingerprint=self.system_fingerprint,
         )
 
     def _callback(self, x, *, buffer, done_generating):
```
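
As a usage sketch (not part of the commit), the renamed streaming path and the new non-streaming path can be driven as follows. `gen` is assumed to be an already-initialized `OpenAiApiGenerator`, and the model name and message content are placeholders.

```python
# Hypothetical driver for the two completion paths shown in the diff above.
# Assumes `gen` is an initialized OpenAiApiGenerator instance.
from api.api import CompletionRequest

req = CompletionRequest(
    model="stories15M",  # placeholder; the server selects the model via --model
    messages=[{"role": "user", "content": "Tell me a short story."}],
    temperature=0.0,
)

# Streaming path: chunked_completion yields CompletionResponseChunk objects,
# ending with a chunk whose finish_reason is "stop".
for chunk in gen.chunked_completion(req):
    if not chunk.choices[0].finish_reason:
        print(chunk.choices[0].delta.content or "", end="")

# Non-streaming path: sync_completion drains chunked_completion internally
# and returns one CompletionResponse holding the full assistant message.
resp = gen.sync_completion(req)
print(resp.choices[0].message.content)
```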

server.py (+49 -32)

```diff
@@ -4,18 +4,36 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from api.api import AssistantMessage, CompletionRequest, OpenAiApiGenerator
+import json
+
+from dataclasses import asdict
+from typing import Dict, List, Union
+
+from api.api import AssistantMessage, CompletionRequest, OpenAiApiGenerator, UserMessage
 
 from build.builder import BuilderArgs, TokenizerArgs
-from flask import Flask, jsonify, request, Response
+from flask import Flask, request, Response
 from generate import GeneratorArgs
 
+
+"""
+Creates a flask app that can be used to serve the model as a chat API.
+"""
 app = Flask(__name__)
 # Messages and gen are kept global so they can be accessed by the flask app endpoints.
 messages: list = []
 gen: OpenAiApiGenerator = None
 
 
+def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
+    """Recursively delete None values from a dictionary."""
+    if type(d) is dict:
+        return {k: _del_none(v) for k, v in d.items() if v}
+    elif type(d) is list:
+        return [_del_none(v) for v in d if v]
+    return d
+
+
 @app.route("/chat", methods=["POST"])
 def chat_endpoint():
     """
@@ -26,45 +44,44 @@ def chat_endpoint():
 
     See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details.
 
+    If stream is set to true, the response will be streamed back as a series of CompletionResponseChunk objects. Otherwise,
+    a single CompletionResponse object will be returned.
     """
-    data = request.get_json()
-
-    # Add user message to chat history
-    messages.append(data["messages"][-1])
-    prompt = messages[-1]["content"]
 
-    # Generate the assistant response
-    req = CompletionRequest(
-        model=gen.builder_args.checkpoint_path,
-        prompt=prompt,
-        temperature=0,
-        messages=[],
-    )
+    print(" === Completion Request ===")
 
-    response = ""
+    # Parse the request into a CompletionRequest object
+    data = request.get_json()
+    req = CompletionRequest(**data)
 
-    def unwrap(completion_generator):
-        token_count = 0
-        for chunk_response in completion_generator:
-            content = chunk_response.choices[0].delta.content
-            if not gen.is_llama3_model or content not in set(
-                gen.tokenizer.special_tokens.keys()
-            ):
-                yield content if content is not None else ""
-            if content == gen.tokenizer.eos_id():
-                yield "."
-            token_count += 1
+    # Add the user message to our internal message history.
+    messages.append(UserMessage(**req.messages[-1]))
 
     if data.get("stream") == "true":
-        return Response(unwrap(gen.completion(req)), mimetype="text/event-stream")
+
+        def chunk_processor(chunked_completion_generator):
+            """Inline function for postprocessing CompletionResponseChunk objects.
+
+            Here, we just jsonify the chunk and yield it as a string.
+            """
+            messages.append(AssistantMessage(content=""))
+            for chunk in chunked_completion_generator:
+                if (next_tok := chunk.choices[0].delta.content) is None:
+                    next_tok = ""
+                messages[-1].content += next_tok
+                print(next_tok, end="")
+                yield json.dumps(_del_none(asdict(chunk)))
+
+        return Response(
+            chunk_processor(gen.chunked_completion(req)), mimetype="text/event-stream"
+        )
     else:
-        for content in unwrap(gen.completion(req)):
-            response += content
+        response = gen.sync_completion(req)
 
-        # Add assistant response to chat history
-        messages.append(AssistantMessage(content=response))
+        messages.append(response.choices[0].message)
+        print(messages[-1].content)
 
-    return jsonify({"response": response})
+    return json.dumps(_del_none(asdict(response)))
 
 
 def initialize_generator(args) -> OpenAiApiGenerator:
```
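
And a minimal client sketch against the updated endpoint. This is also not part of the commit: the host and port are assumptions, the handler compares `stream` against the string `"true"`, and passing `stream` in the body assumes `CompletionRequest` accepts a `stream` field, since the handler forwards the raw JSON via `CompletionRequest(**data)`.

```python
# Hypothetical client for the /chat endpoint; assumes the Flask app runs on
# localhost:5000 (Flask's default) -- adjust host/port to your setup.
import requests

body = {
    "model": "stories15M",  # required by CompletionRequest, but not used for model selection
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": "false",  # the handler only streams when this is the string "true"
}

resp = requests.post("http://localhost:5000/chat", json=body)
data = resp.json()  # an OpenAI-style chat completion object
print(data["choices"][0]["message"]["content"])
```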
