Commit 8472f6d

JSON formatted response using OpenAI API types for server completion requests
Parent: 6303c8c

File tree: 2 files changed (+88 −43)

  api/api.py
  server.py

api/api.py (+40 −11)
@@ -95,8 +95,7 @@ class CompletionRequest:
     """
 
     model: str
-    prompt: str
-    messages: Optional[List[_AbstractMessage]]
+    messages: List[_AbstractMessage]
     frequency_penalty: float = 0.0
     temperature: float = 0.0
     stop: Optional[List[str]] = None
@@ -121,10 +120,10 @@ class CompletionChoice:
     See the "The chat completion object >>> choices" section of the OpenAI API docs for more details.
     """
 
-    finish_reason: str
     index: int
     message: AssistantMessage
-    logprobs: Optional[List[Any]]
+    finish_reason: str = None
+    logprobs: Optional[List[Any]] = None
 
 
 @dataclass
@@ -151,7 +150,7 @@ class CompletionResponse:
     created: int
     model: str
     system_fingerprint: str
-    usage: UsageStats
+    usage: Optional[UsageStats] = None
     object: str = "chat.completion"
     service_tier: Optional[str] = None
 
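With `prompt` removed from CompletionRequest, the prompt text now travels as the last entry of `messages`, mirroring the OpenAI chat schema, and the optional defaults on `finish_reason`, `logprobs`, and `usage` let partially populated responses be constructed while the server is still incomplete. A rough sketch of a request body under the new schema (the model name and message dicts here are illustrative placeholders, not taken from this commit):

    from api.api import CompletionRequest

    # Hypothetical request body; the "model" value and message shape are assumptions.
    request_body = {
        "model": "llama3",
        "messages": [
            {"role": "user", "content": "What is the capital of France?"},
        ],
        "temperature": 0.0,
    }
    req = CompletionRequest(**request_body)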
@@ -220,8 +219,11 @@ def __init__(self, *args, **kwargs):
             if self.draft_model is not None
             else self.model.config.max_seq_length
         )
+        self.system_fingerprint = (
+            self.builder_args.device + type(self.builder_args.precision).__name__
+        )
 
-    def completion(self, completion_request: CompletionRequest):
+    def chunked_completion(self, completion_request: CompletionRequest):
         """Handle a chat completion request and yield a chunked response.
 
         ** Warning ** : Not all arguments of the CompletionRequest are consumed as the server isn't completely implemented.
@@ -246,13 +248,16 @@ def completion(self, completion_request: CompletionRequest):
 
         # Initialize counters for chunk responses and encode the prompt.
         id = str(uuid.uuid4())
+
         idx = 0
         buffer = []
         encoded = self.encode_tokens(
-            completion_request.prompt, bos=True, device=self.builder_args.device
+            completion_request.messages[-1].get("content"),
+            bos=True,
+            device=self.builder_args.device,
         )
         generator_args = GeneratorArgs(
-            completion_request.prompt,
+            completion_request.messages[-1].get("content"),
             encoded_prompt=encoded,
             chat_mode=False,
         )
@@ -302,21 +307,45 @@ def callback(x, *, done_generating=False):
                 choices=[choice_chunk],
                 created=int(time.time()),
                 model=completion_request.model,
-                system_fingerprint=uuid.UUID(int=uuid.getnode()),
+                system_fingerprint=self.system_fingerprint,
             )
             yield chunk_response
             self.start_pos += y.size(0)
             idx += 1
 
         # Yield an ending chunk indicating the generation has completed.
-        end_chunk = CompletionChoiceChunk(ChunkDelta(None, None, None), idx, "eos")
+        end_chunk = CompletionChoiceChunk(
+            ChunkDelta(None, None, None), idx, finish_reason="stop"
+        )
 
         yield CompletionResponseChunk(
             id=str(id),
             choices=[end_chunk],
             created=int(time.time()),
             model=completion_request.model,
-            system_fingerprint=uuid.UUID(int=uuid.getnode()),
+            system_fingerprint=self.system_fingerprint,
+        )
+
+    def sync_completion(self, request: CompletionRequest):
+        """Handle a chat completion request and yield a single, non-chunked response"""
+        output = ""
+        for chunk in self.chunked_completion(request):
+            if not chunk.choices[0].finish_reason:
+                output += chunk.choices[0].delta.content
+
+        message = AssistantMessage(content=output)
+        return CompletionResponse(
+            id=str(uuid.uuid4()),
+            choices=[
+                CompletionChoice(
+                    finish_reason="stop",
+                    index=0,
+                    message=message,
+                )
+            ],
+            created=int(time.time()),
+            model=request.model,
+            system_fingerprint=self.system_fingerprint,
         )
 
     def _callback(self, x, *, buffer, done_generating):
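The split into `chunked_completion` and `sync_completion` gives callers a streaming generator and a blocking wrapper over the same generation path. A minimal sketch of the difference, assuming an initialized `OpenAiApiGenerator` named `gen` and a `req` like the one above (neither is shown in this diff):

    # Streaming: iterate CompletionResponseChunk objects as tokens arrive.
    for chunk in gen.chunked_completion(req):
        if not chunk.choices[0].finish_reason:  # final chunk carries "stop" and an empty delta
            print(chunk.choices[0].delta.content or "", end="")

    # Blocking: sync_completion drains chunked_completion internally and
    # returns one CompletionResponse holding the assembled AssistantMessage.
    response = gen.sync_completion(req)
    print(response.choices[0].message.content)

Note that `sync_completion` filters on `finish_reason` exactly as the loop above does, so the terminal "stop" chunk never contributes to the output string.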

server.py (+48 −32)
@@ -4,18 +4,35 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from api.api import AssistantMessage, CompletionRequest, OpenAiApiGenerator
+import json
+
+from dataclasses import asdict
+
+from api.api import AssistantMessage, CompletionRequest, OpenAiApiGenerator, UserMessage
 
 from build.builder import BuilderArgs, TokenizerArgs
-from flask import Flask, jsonify, request, Response
+from flask import Flask, request, Response
 from generate import GeneratorArgs
 
+
+"""
+Creates a flask app that can be used to serve the model as a chat API.
+"""
 app = Flask(__name__)
 # Messages and gen are kept global so they can be accessed by the flask app endpoints.
 messages: list = []
 gen: OpenAiApiGenerator = None
 
 
+def _del_none(d: dict):
+    """Recursively delete None values from a dictionary."""
+    if type(d) is dict:
+        return {k: _del_none(v) for k, v in d.items() if v}
+    elif type(d) is list:
+        return [_del_none(v) for v in d if v]
+    return d
+
+
 @app.route("/chat", methods=["POST"])
 def chat_endpoint():
     """
@@ -26,45 +43,44 @@ def chat_endpoint():
 
     See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details.
 
+    If stream is set to true, the response will be streamed back as a series of CompletionResponseChunk objects. Otherwise,
+    a single CompletionResponse object will be returned.
     """
-    data = request.get_json()
-
-    # Add user message to chat history
-    messages.append(data["messages"][-1])
-    prompt = messages[-1]["content"]
 
-    # Generate the assistant response
-    req = CompletionRequest(
-        model=gen.builder_args.checkpoint_path,
-        prompt=prompt,
-        temperature=0,
-        messages=[],
-    )
+    print(" === Completion Request ===")
 
-    response = ""
+    # Parse the request in to a CompletionRequest object
+    data = request.get_json()
+    req = CompletionRequest(**data)
 
-    def unwrap(completion_generator):
-        token_count = 0
-        for chunk_response in completion_generator:
-            content = chunk_response.choices[0].delta.content
-            if not gen.is_llama3_model or content not in set(
-                gen.tokenizer.special_tokens.keys()
-            ):
-                yield content if content is not None else ""
-            if content == gen.tokenizer.eos_id():
-                yield "."
-            token_count += 1
+    # Add the user message to our internal message history.
+    messages.append(UserMessage(**req.messages[-1]))
 
     if data.get("stream") == "true":
-        return Response(unwrap(gen.completion(req)), mimetype="text/event-stream")
+
+        def chunk_processor(chunked_completion_generator):
+            """Inline function for postprocessing CompletionResponseChunk objects.
+
+            Here, we just jsonify the chunk and yield it as a string.
+            """
+            messages.append(AssistantMessage(content=""))
+            for chunk in chunked_completion_generator:
+                nextok = chunk.choices[0].delta.content
+                nextok = nextok if nextok is not None else ""
+                messages[-1].content += nextok
+                print(nextok, end="")
+                yield json.dumps(_del_none(asdict(chunk)))
+
+        return Response(
+            chunk_processor(gen.chunked_completion(req)), mimetype="text/event-stream"
+        )
     else:
-        for content in unwrap(gen.completion(req)):
-            response += content
+        response = gen.sync_completion(req)
 
-        # Add assistant response to chat history
-        messages.append(AssistantMessage(content=response))
+        messages.append(response.choices[0].message)
+        print(messages[-1].content)
 
-    return jsonify({"response": response})
+    return json.dumps(_del_none(asdict(response)))
 
 
 def initialize_generator(args) -> OpenAiApiGenerator:
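To exercise the reworked endpoint end to end, something like the following should work, assuming the Flask app runs on its default http://127.0.0.1:5000 and that CompletionRequest tolerates the extra stream field (both are assumptions, not verified by this diff). Note the handler compares stream against the string "true", not a JSON boolean:

    import requests

    body = {
        "model": "llama3",  # placeholder; the server no longer substitutes its checkpoint path
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": "true",   # compared as a string by data.get("stream") == "true"
    }

    # Streaming: the event stream is a concatenation of JSON-encoded
    # CompletionResponseChunk objects, printed here as raw text.
    with requests.post("http://127.0.0.1:5000/chat", json=body, stream=True) as resp:
        for piece in resp.iter_content(chunk_size=None, decode_unicode=True):
            print(piece, end="")

    # Non-streaming: any other stream value falls through to sync_completion
    # and returns one JSON-encoded CompletionResponse.
    body["stream"] = "false"
    print(requests.post("http://127.0.0.1:5000/chat", json=body).json())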
