
Commit 6401f55

JSON formatted response using OpenAI API types for server completion requests
1 parent 6303c8c commit 6401f55

File tree

3 files changed: +123 −64 lines changed

api/api.py
generate.py
server.py

api/api.py

+69-31
@@ -8,7 +8,7 @@
 import uuid
 from abc import ABC
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from build.utils import device_sync

@@ -86,6 +86,9 @@ class StreamOptions:

     include_usage: bool = False

+@dataclass
+class ResponseFormat:
+    type: Optional[str] = None

 @dataclass
 class CompletionRequest:
@@ -94,25 +97,27 @@ class CompletionRequest:
     See the "Create Chat Completion >>> Request body" section of the OpenAI API docs for more details.
     """

+    messages: List[_AbstractMessage]
     model: str
-    prompt: str
-    messages: Optional[List[_AbstractMessage]]
-    frequency_penalty: float = 0.0
-    temperature: float = 0.0
-    stop: Optional[List[str]] = None
-    stream: bool = False
-    stream_options: Optional[StreamOptions] = None
-    echo: bool = False
-    frequency_penalty: float = 0.0
-    guided_decode_json_schema: str = None
-    guided_decode_json_schema_path: str = None
+    frequency_penalty: float = 0.0  # unimplemented
+    logit_bias: Optional[Dict[str, float]] = None  # unimplemented
+    logprobs: Optional[bool] = None  # unimplemented
+    top_logprobs: Optional[int] = None  # unimplemented
+    max_tokens: Optional[int] = None  # unimplemented
     n: int = 1
-    presence_penalty: float = 0
-    logit_bias: Optional[Dict[str, float]] = None
-    logprobs: Optional[bool] = None
-    top_logprobs: Optional[int] = None
-    max_tokens: Optional[int] = None
-
+    presence_penalty: float = 0  # unimplemented
+    response_format: Optional[ResponseFormat] = None  # unimplemented
+    seed: Optional[int] = None  # unimplemented
+    service_tier: Optional[str] = None  # unimplemented
+    stop: Optional[List[str]] = None  # unimplemented
+    stream: bool = False
+    stream_options: Optional[StreamOptions] = None  # unimplemented
+    temperature: Optional[float] = 1.0  # unimplemented
+    top_p: Optional[float] = 1.0  # unimplemented
+    tools: Optional[List[Any]] = None  # unimplemented
+    tool_choice: Optional[Union[str, Any]] = None  # unimplemented
+    parallel_tool_calls: Optional[bool] = None  # unimplemented
+    user: Optional[str] = None  # unimplemented

 @dataclass
 class CompletionChoice:
@@ -121,10 +126,10 @@ class CompletionChoice:
     See the "The chat completion object >>> choices" section of the OpenAI API docs for more details.
     """

-    finish_reason: str
     index: int
     message: AssistantMessage
-    logprobs: Optional[List[Any]]
+    finish_reason: str = None
+    logprobs: Optional[List[Any]] = None


 @dataclass
@@ -150,10 +155,10 @@ class CompletionResponse:
     choices: List[CompletionChoice]
     created: int
     model: str
-    system_fingerprint: str
-    usage: UsageStats
-    object: str = "chat.completion"
+    system_fingerprint: str
     service_tier: Optional[str] = None
+    usage: Optional[UsageStats] = None
+    object: str = "chat.completion"


 @dataclass
@@ -193,8 +198,8 @@ class CompletionResponseChunk:
     created: int
     model: str
     system_fingerprint: str
-    object: str = "chat.completion.chunk"
     service_tier: Optional[str] = None
+    object: str = "chat.completion.chunk"
     usage: Optional[UsageStats] = None

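A note on the request shape: because CompletionRequest is a plain dataclass mirroring the OpenAI "create chat completion" body, an incoming JSON payload can be unpacked straight into it, which is exactly how server.py consumes it later in this commit. A minimal sketch (assuming it is run from a torchchat checkout so api.api is importable; the payload values are made up):

import json

from api.api import CompletionRequest

# An OpenAI-style "create chat completion" request body (values are made up).
payload = """
{
    "model": "llama3",
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.7,
    "stream": false
}
"""

# Unpack the parsed JSON directly into the dataclass; fields absent from the
# payload (seed, tools, response_format, ...) keep their declared defaults and
# are currently ignored by the server ("unimplemented").
req = CompletionRequest(**json.loads(payload))
print(req.model, req.messages[-1]["content"], req.temperature)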
@@ -220,8 +225,13 @@ def __init__(self, *args, **kwargs):
             if self.draft_model is not None
             else self.model.config.max_seq_length
         )
+        # The System fingerprint is a unique identifier for the model and its configuration.
+        # Currently, this is not implemented in a
+        self.system_fingerprint = (
+            self.builder_args.device + type(self.builder_args.precision).__name__
+        )

-    def completion(self, completion_request: CompletionRequest):
+    def chunked_completion(self, completion_request: CompletionRequest):
         """Handle a chat completion request and yield a chunked response.

         ** Warning ** : Not all arguments of the CompletionRequest are consumed as the server isn't completely implemented.
@@ -230,7 +240,8 @@ def completion(self, completion_request: CompletionRequest):
         - messages: The server consumes the final element of the array as the prompt.
         - model: This has no impact on the server state, i.e. changing the model in the request
              will not change which model is responding. Instead, use the --model flag to seelect the model when starting the server.
-        - temperature: This is used to control the randomness of the response. The server will use the temperature
+        - temperature: This is used to control the randomness of the response.
+        - system_fingerprint: A unique identifier for the model and its configuration. Currently unimplemented - subject to change.

         See https://github.com/pytorch/torchchat/issues/973 for more details.

@@ -246,13 +257,16 @@ def completion(self, completion_request: CompletionRequest):

         # Initialize counters for chunk responses and encode the prompt.
         id = str(uuid.uuid4())
+
         idx = 0
         buffer = []
         encoded = self.encode_tokens(
-            completion_request.prompt, bos=True, device=self.builder_args.device
+            completion_request.messages[-1].get("content"),
+            bos=True,
+            device=self.builder_args.device,
         )
         generator_args = GeneratorArgs(
-            completion_request.prompt,
+            completion_request.messages[-1].get("content"),
             encoded_prompt=encoded,
             chat_mode=False,
         )
@@ -302,21 +316,45 @@ def callback(x, *, done_generating=False):
                 choices=[choice_chunk],
                 created=int(time.time()),
                 model=completion_request.model,
-                system_fingerprint=uuid.UUID(int=uuid.getnode()),
+                system_fingerprint=self.system_fingerprint,
             )
             yield chunk_response
             self.start_pos += y.size(0)
             idx += 1

         # Yield an ending chunk indicating the generation has completed.
-        end_chunk = CompletionChoiceChunk(ChunkDelta(None, None, None), idx, "eos")
+        end_chunk = CompletionChoiceChunk(
+            ChunkDelta(None, None, None), idx, finish_reason="stop"
+        )

         yield CompletionResponseChunk(
             id=str(id),
             choices=[end_chunk],
             created=int(time.time()),
             model=completion_request.model,
-            system_fingerprint=uuid.UUID(int=uuid.getnode()),
+            system_fingerprint=self.system_fingerprint,
+        )
+
+    def sync_completion(self, request: CompletionRequest):
+        """Handle a chat completion request and yield a single, non-chunked response"""
+        output = ""
+        for chunk in self.chunked_completion(request):
+            if not chunk.choices[0].finish_reason:
+                output += chunk.choices[0].delta.content
+
+        message = AssistantMessage(content=output)
+        return CompletionResponse(
+            id=str(uuid.uuid4()),
+            choices=[
+                CompletionChoice(
+                    finish_reason="stop",
+                    index=0,
+                    message=message,
+                )
+            ],
+            created=int(time.time()),
+            model=request.model,
+            system_fingerprint=self.system_fingerprint,
         )

     def _callback(self, x, *, buffer, done_generating):
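The two methods above give a streaming and a non-streaming path over the same generator. A usage sketch, assuming gen is an already-initialized OpenAiApiGenerator (as built by initialize_generator in server.py) and noting that each call below runs a separate generation:

from api.api import CompletionRequest

# `gen` is assumed to be an already-initialized OpenAiApiGenerator.
req = CompletionRequest(
    model="llama3",
    messages=[{"role": "user", "content": "Tell me a joke."}],
    stream=True,
)

# Streaming path: one CompletionResponseChunk per generated token,
# plus a final chunk whose finish_reason is "stop".
for chunk in gen.chunked_completion(req):
    if not chunk.choices[0].finish_reason:
        print(chunk.choices[0].delta.content, end="")

# Non-streaming path: a single CompletionResponse assembled from the chunks
# (this is a second, independent generation).
response = gen.sync_completion(req)
print(response.choices[0].message.content)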

generate.py

+5-1
@@ -447,14 +447,18 @@ def generate(
     start_pos: int = 0,
     draft_model: Transformer,
     speculate_k: Optional[int] = 8,
-    sequential_prefill=True,
+    sequential_prefill=True,
     callback=lambda x: x,
     max_seq_length: int,
+    seed: Optional[int] = None,
     **sampling_kwargs,
 ) -> torch.Tensor:
     """
     Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
     """
+    if seed:
+        torch.manual_seed(seed)
+
     is_speculative = draft_model is not None
     device, dtype = prompt.device, prompt.dtype

server.py

+49-32
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,36 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from api.api import AssistantMessage, CompletionRequest, OpenAiApiGenerator
7+
import json
8+
9+
from dataclasses import asdict
10+
from typing import Dict, List, Union
11+
12+
from api.api import AssistantMessage, CompletionRequest, OpenAiApiGenerator, UserMessage
813

914
from build.builder import BuilderArgs, TokenizerArgs
10-
from flask import Flask, jsonify, request, Response
15+
from flask import Flask, request, Response
1116
from generate import GeneratorArgs
1217

18+
19+
"""
20+
Creates a flask app that can be used to serve the model as a chat API.
21+
"""
1322
app = Flask(__name__)
1423
# Messages and gen are kept global so they can be accessed by the flask app endpoints.
1524
messages: list = []
1625
gen: OpenAiApiGenerator = None
1726

1827

28+
def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
29+
"""Recursively delete None values from a dictionary."""
30+
if type(d) is dict:
31+
return {k: _del_none(v) for k, v in d.items() if v}
32+
elif type(d) is list:
33+
return [_del_none(v) for v in d if v]
34+
return d
35+
36+
1937
@app.route("/chat", methods=["POST"])
2038
def chat_endpoint():
2139
"""
@@ -26,45 +44,44 @@ def chat_endpoint():

     See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details.

+    If stream is set to true, the response will be streamed back as a series of CompletionResponseChunk objects. Otherwise,
+    a single CompletionResponse object will be returned.
     """
-    data = request.get_json()
-
-    # Add user message to chat history
-    messages.append(data["messages"][-1])
-    prompt = messages[-1]["content"]

-    # Generate the assistant response
-    req = CompletionRequest(
-        model=gen.builder_args.checkpoint_path,
-        prompt=prompt,
-        temperature=0,
-        messages=[],
-    )
+    print(" === Completion Request ===")

-    response = ""
+    # Parse the request in to a CompletionRequest object
+    data = request.get_json()
+    req = CompletionRequest(**data)

-    def unwrap(completion_generator):
-        token_count = 0
-        for chunk_response in completion_generator:
-            content = chunk_response.choices[0].delta.content
-            if not gen.is_llama3_model or content not in set(
-                gen.tokenizer.special_tokens.keys()
-            ):
-                yield content if content is not None else ""
-            if content == gen.tokenizer.eos_id():
-                yield "."
-            token_count += 1
+    # Add the user message to our internal message history.
+    messages.append(UserMessage(**req.messages[-1]))

     if data.get("stream") == "true":
-        return Response(unwrap(gen.completion(req)), mimetype="text/event-stream")
+
+        def chunk_processor(chunked_completion_generator):
+            """Inline function for postprocessing CompletionResponseChunk objects.
+
+            Here, we just jsonify the chunk and yield it as a string.
+            """
+            messages.append(AssistantMessage(content=""))
+            for chunk in chunked_completion_generator:
+                if (next_tok := chunk.choices[0].delta.content) is None:
+                    next_tok = ""
+                messages[-1].content += next_tok
+                print(next_tok, end="")
+                yield json.dumps(_del_none(asdict(chunk)))
+
+        return Response(
+            chunk_processor(gen.chunked_completion(req)), mimetype="text/event-stream"
+        )
     else:
-        for content in unwrap(gen.completion(req)):
-            response += content
+        response = gen.sync_completion(req)

-        # Add assistant response to chat history
-        messages.append(AssistantMessage(content=response))
+        messages.append(response.choices[0].message)
+        print(messages[-1].content)

-    return jsonify({"response": response})
+    return json.dumps(_del_none(asdict(response)))


 def initialize_generator(args) -> OpenAiApiGenerator:
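For reference, a client-side sketch of calling the endpoint. The host and port are assumptions (whatever the Flask app is served on), and note that the handler decides to stream on the literal string "true" rather than a JSON boolean:

import requests  # any HTTP client works; requests is an assumption

BASE = "http://127.0.0.1:5000"  # assumed host/port for the Flask app

body = {
    "model": "llama3",
    "messages": [{"role": "user", "content": "Hello!"}],
}

# Non-streaming: the endpoint returns a single chat.completion JSON object.
resp = requests.post(f"{BASE}/chat", json=body)
completion = resp.json()
print(completion["choices"][0]["message"]["content"])

# Streaming: chunks arrive as back-to-back JSON objects on an event stream;
# here we just echo the raw text as it comes in.
stream_body = {**body, "stream": "true"}
with requests.post(f"{BASE}/chat", json=stream_body, stream=True) as resp:
    for piece in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(piece, end="")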
