Commit 4e26b22

JSON formatted response using OpenAI API types for server completion requests
1 parent 6303c8c commit 4e26b22

File tree

api/api.py
generate.py
server.py

3 files changed: +123 -61 lines changed

api/api.py (+70 -29)

@@ -8,7 +8,7 @@
 import uuid
 from abc import ABC
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from build.utils import device_sync

@@ -87,31 +87,39 @@ class StreamOptions:
     include_usage: bool = False


+@dataclass
+class ResponseFormat:
+    type: Optional[str] = None
+
+
 @dataclass
 class CompletionRequest:
     """A full chat completion request.

     See the "Create Chat Completion >>> Request body" section of the OpenAI API docs for more details.
     """

+    messages: List[_AbstractMessage]
     model: str
-    prompt: str
-    messages: Optional[List[_AbstractMessage]]
-    frequency_penalty: float = 0.0
-    temperature: float = 0.0
-    stop: Optional[List[str]] = None
-    stream: bool = False
-    stream_options: Optional[StreamOptions] = None
-    echo: bool = False
-    frequency_penalty: float = 0.0
-    guided_decode_json_schema: str = None
-    guided_decode_json_schema_path: str = None
+    frequency_penalty: float = 0.0  # unimplemented
+    logit_bias: Optional[Dict[str, float]] = None  # unimplemented
+    logprobs: Optional[bool] = None  # unimplemented
+    top_logprobs: Optional[int] = None  # unimplemented
+    max_tokens: Optional[int] = None  # unimplemented
     n: int = 1
-    presence_penalty: float = 0
-    logit_bias: Optional[Dict[str, float]] = None
-    logprobs: Optional[bool] = None
-    top_logprobs: Optional[int] = None
-    max_tokens: Optional[int] = None
+    presence_penalty: float = 0  # unimplemented
+    response_format: Optional[ResponseFormat] = None  # unimplemented
+    seed: Optional[int] = None  # unimplemented
+    service_tier: Optional[str] = None  # unimplemented
+    stop: Optional[List[str]] = None  # unimplemented
+    stream: bool = False
+    stream_options: Optional[StreamOptions] = None  # unimplemented
+    temperature: Optional[float] = 1.0  # unimplemented
+    top_p: Optional[float] = 1.0  # unimplemented
+    tools: Optional[List[Any]] = None  # unimplemented
+    tool_choice: Optional[Union[str, Any]] = None  # unimplemented
+    parallel_tool_calls: Optional[bool] = None  # unimplemented
+    user: Optional[str] = None  # unimplemented


 @dataclass
@@ -121,10 +129,10 @@ class CompletionChoice:
     See the "The chat completion object >>> choices" section of the OpenAI API docs for more details.
     """

-    finish_reason: str
     index: int
     message: AssistantMessage
-    logprobs: Optional[List[Any]]
+    finish_reason: str = None
+    logprobs: Optional[List[Any]] = None


 @dataclass
@@ -151,9 +159,9 @@ class CompletionResponse:
     created: int
     model: str
     system_fingerprint: str
-    usage: UsageStats
-    object: str = "chat.completion"
     service_tier: Optional[str] = None
+    usage: Optional[UsageStats] = None
+    object: str = "chat.completion"


 @dataclass
@@ -193,8 +201,8 @@ class CompletionResponseChunk:
     created: int
     model: str
     system_fingerprint: str
-    object: str = "chat.completion.chunk"
     service_tier: Optional[str] = None
+    object: str = "chat.completion.chunk"
     usage: Optional[UsageStats] = None


@@ -220,8 +228,13 @@ def __init__(self, *args, **kwargs):
             if self.draft_model is not None
             else self.model.config.max_seq_length
         )
+        # The System fingerprint is a unique identifier for the model and its configuration.
+        # Currently, this is not implemented in a
+        self.system_fingerprint = (
+            self.builder_args.device + type(self.builder_args.precision).__name__
+        )

-    def completion(self, completion_request: CompletionRequest):
+    def chunked_completion(self, completion_request: CompletionRequest):
         """Handle a chat completion request and yield a chunked response.

         ** Warning ** : Not all arguments of the CompletionRequest are consumed as the server isn't completely implemented.
@@ -230,7 +243,8 @@ def completion(self, completion_request: CompletionRequest):
         - messages: The server consumes the final element of the array as the prompt.
         - model: This has no impact on the server state, i.e. changing the model in the request
         will not change which model is responding. Instead, use the --model flag to seelect the model when starting the server.
-        - temperature: This is used to control the randomness of the response. The server will use the temperature
+        - temperature: This is used to control the randomness of the response.
+        - system_fingerprint: A unique identifier for the model and its configuration. Currently unimplemented - subject to change.

         See https://github.com/pytorch/torchchat/issues/973 for more details.

@@ -246,13 +260,16 @@ def completion(self, completion_request: CompletionRequest):

         # Initialize counters for chunk responses and encode the prompt.
         id = str(uuid.uuid4())
+
         idx = 0
         buffer = []
         encoded = self.encode_tokens(
-            completion_request.prompt, bos=True, device=self.builder_args.device
+            completion_request.messages[-1].get("content"),
+            bos=True,
+            device=self.builder_args.device,
         )
         generator_args = GeneratorArgs(
-            completion_request.prompt,
+            completion_request.messages[-1].get("content"),
             encoded_prompt=encoded,
             chat_mode=False,
         )
@@ -302,21 +319,45 @@ def callback(x, *, done_generating=False):
                 choices=[choice_chunk],
                 created=int(time.time()),
                 model=completion_request.model,
-                system_fingerprint=uuid.UUID(int=uuid.getnode()),
+                system_fingerprint=self.system_fingerprint,
             )
             yield chunk_response
             self.start_pos += y.size(0)
             idx += 1

         # Yield an ending chunk indicating the generation has completed.
-        end_chunk = CompletionChoiceChunk(ChunkDelta(None, None, None), idx, "eos")
+        end_chunk = CompletionChoiceChunk(
+            ChunkDelta(None, None, None), idx, finish_reason="stop"
+        )

         yield CompletionResponseChunk(
             id=str(id),
             choices=[end_chunk],
             created=int(time.time()),
             model=completion_request.model,
-            system_fingerprint=uuid.UUID(int=uuid.getnode()),
+            system_fingerprint=self.system_fingerprint,
+        )
+
+    def sync_completion(self, request: CompletionRequest):
+        """Handle a chat completion request and yield a single, non-chunked response"""
+        output = ""
+        for chunk in self.chunked_completion(request):
+            if not chunk.choices[0].finish_reason:
+                output += chunk.choices[0].delta.content
+
+        message = AssistantMessage(content=output)
+        return CompletionResponse(
+            id=str(uuid.uuid4()),
+            choices=[
+                CompletionChoice(
+                    finish_reason="stop",
+                    index=0,
+                    message=message,
+                )
+            ],
+            created=int(time.time()),
+            model=request.model,
+            system_fingerprint=self.system_fingerprint,
         )

     def _callback(self, x, *, buffer, done_generating):
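
For orientation (not part of the commit), the reworked CompletionRequest can be built straight from a decoded JSON request body. A minimal sketch, assuming the repo root is on the import path; the model name and message are made-up values, and unknown top-level keys would raise a TypeError since CompletionRequest is a plain dataclass:

import json

from api.api import CompletionRequest

# Hypothetical JSON body a client might POST to the server; the field names
# mirror the OpenAI "Create Chat Completion" request body modeled above.
raw = '{"model": "llama3", "messages": [{"role": "user", "content": "Hello!"}]}'
req = CompletionRequest(**json.loads(raw))

# Messages stay as plain dicts, which is why the generator reads the prompt
# via completion_request.messages[-1].get("content").
print(req.messages[-1].get("content"))  # -> Hello!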

generate.py (+4 -0)

@@ -450,11 +450,15 @@ def generate(
     sequential_prefill=True,
     callback=lambda x: x,
     max_seq_length: int,
+    seed: Optional[int] = None,
     **sampling_kwargs,
 ) -> torch.Tensor:
     """
     Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
     """
+    if seed:
+        torch.manual_seed(seed)
+
     is_speculative = draft_model is not None
     device, dtype = prompt.device, prompt.dtype
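
Side note (not from the commit): seeding the global RNG before sampling is what makes repeated requests with the same seed reproducible; note also that the "if seed:" check treats seed=0 the same as no seed. A tiny illustration of the effect:

import torch

# Identical seeds produce identical random draws.
torch.manual_seed(42)
first = torch.multinomial(torch.ones(10), 3)
torch.manual_seed(42)
second = torch.multinomial(torch.ones(10), 3)
assert torch.equal(first, second)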

server.py (+49 -32)

@@ -4,18 +4,36 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from api.api import AssistantMessage, CompletionRequest, OpenAiApiGenerator
+import json
+
+from dataclasses import asdict
+from typing import Dict, List, Union
+
+from api.api import AssistantMessage, CompletionRequest, OpenAiApiGenerator, UserMessage

 from build.builder import BuilderArgs, TokenizerArgs
-from flask import Flask, jsonify, request, Response
+from flask import Flask, request, Response
 from generate import GeneratorArgs

+
+"""
+Creates a flask app that can be used to serve the model as a chat API.
+"""
 app = Flask(__name__)
 # Messages and gen are kept global so they can be accessed by the flask app endpoints.
 messages: list = []
 gen: OpenAiApiGenerator = None


+def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
+    """Recursively delete None values from a dictionary."""
+    if type(d) is dict:
+        return {k: _del_none(v) for k, v in d.items() if v}
+    elif type(d) is list:
+        return [_del_none(v) for v in d if v]
+    return d
+
+
 @app.route("/chat", methods=["POST"])
 def chat_endpoint():
     """
@@ -26,45 +44,44 @@ def chat_endpoint():

     See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details.

+    If stream is set to true, the response will be streamed back as a series of CompletionResponseChunk objects. Otherwise,
+    a single CompletionResponse object will be returned.
     """
-    data = request.get_json()
-
-    # Add user message to chat history
-    messages.append(data["messages"][-1])
-    prompt = messages[-1]["content"]

-    # Generate the assistant response
-    req = CompletionRequest(
-        model=gen.builder_args.checkpoint_path,
-        prompt=prompt,
-        temperature=0,
-        messages=[],
-    )
+    print(" === Completion Request ===")

-    response = ""
+    # Parse the request in to a CompletionRequest object
+    data = request.get_json()
+    req = CompletionRequest(**data)

-    def unwrap(completion_generator):
-        token_count = 0
-        for chunk_response in completion_generator:
-            content = chunk_response.choices[0].delta.content
-            if not gen.is_llama3_model or content not in set(
-                gen.tokenizer.special_tokens.keys()
-            ):
-                yield content if content is not None else ""
-            if content == gen.tokenizer.eos_id():
-                yield "."
-            token_count += 1
+    # Add the user message to our internal message history.
+    messages.append(UserMessage(**req.messages[-1]))

     if data.get("stream") == "true":
-        return Response(unwrap(gen.completion(req)), mimetype="text/event-stream")
+
+        def chunk_processor(chunked_completion_generator):
+            """Inline function for postprocessing CompletionResponseChunk objects.
+
+            Here, we just jsonify the chunk and yield it as a string.
+            """
+            messages.append(AssistantMessage(content=""))
+            for chunk in chunked_completion_generator:
+                if (next_tok := chunk.choices[0].delta.content) is None:
+                    next_tok = ""
+                messages[-1].content += next_tok
+                print(next_tok, end="")
+                yield json.dumps(_del_none(asdict(chunk)))
+
+        return Response(
+            chunk_processor(gen.chunked_completion(req)), mimetype="text/event-stream"
+        )
     else:
-        for content in unwrap(gen.completion(req)):
-            response += content
+        response = gen.sync_completion(req)

-        # Add assistant response to chat history
-        messages.append(AssistantMessage(content=response))
+        messages.append(response.choices[0].message)
+        print(messages[-1].content)

-        return jsonify({"response": response})
+        return json.dumps(_del_none(asdict(response)))


 def initialize_generator(args) -> OpenAiApiGenerator:
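
For context (not part of the commit), a hedged client-side sketch of exercising the reworked /chat endpoint; the host and port are assumptions, not taken from this diff. Two details visible in the diff: streaming is selected by sending the string "true", and _del_none filters on truthiness, so falsy fields (e.g. index=0 or an empty content string) are dropped from the serialized JSON.

import json

import requests

URL = "http://localhost:5000/chat"  # assumed host/port; adjust for your server

body = {
    "model": "llama3",  # informational only; the served model is picked via --model
    "messages": [{"role": "user", "content": "Tell me a joke."}],
}

# Non-streaming: the endpoint returns a single JSON-encoded CompletionResponse.
resp = requests.post(URL, json=body)
data = json.loads(resp.text)
print(data["choices"][0]["message"]["content"])

# Streaming: the event stream carries JSON-encoded CompletionResponseChunk objects.
body["stream"] = "true"
with requests.post(URL, json=body, stream=True) as r:
    for piece in r.iter_content(chunk_size=None):
        print(piece.decode("utf-8"))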
