forked from ggml-org/llama.cpp
-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtest_chat_completion.py
148 lines (135 loc) · 5.39 KB
/
test_chat_completion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import pytest
from openai import OpenAI
from utils import *
# Module-level server handle used by every test in this file. It is rebound
# by the ``create_server`` fixture below.
server = ServerPreset.tinyllama2()


@pytest.fixture(scope="module", autouse=True)
def create_server():
    """Rebind the module-level ``server`` to a fresh tinyllama2 preset.

    The fixture is ``autouse`` with module scope, so pytest runs it once for
    this module without any test having to request it explicitly.
    """
    global server  # required: the name is assigned, not just read
    server = ServerPreset.tinyllama2()
@pytest.mark.parametrize(
    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
    [
        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
    ]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
    """Check a non-streamed ``POST /chat/completions`` response.

    Verifies the HTTP status, the reported prompt/completion token counts,
    the role and content of the first choice, and its ``finish_reason``
    ("length" when ``truncated`` is true, "stop" otherwise).
    """
    server.start()

    payload = {
        "model": model,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    }
    response = server.make_request("POST", "/chat/completions", data=payload)
    assert response.status_code == 200

    usage = response.body["usage"]
    assert usage["prompt_tokens"] == n_prompt
    assert usage["completion_tokens"] == n_predicted

    first_choice = response.body["choices"][0]
    message = first_choice["message"]
    assert "assistant" == message["role"]
    assert match_regex(re_content, message["content"])

    expected_finish = "length" if truncated else "stop"
    assert first_choice["finish_reason"] == expected_finish
@pytest.mark.parametrize(
    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
    [
        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
    ]
)
def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
    """Check a streamed ``POST /chat/completions`` response.

    Intermediate chunks must carry ``finish_reason`` of ``None`` and a
    ``delta.content`` fragment; these fragments are accumulated. The chunk
    whose ``finish_reason`` is "stop" or "length" must report the expected
    token usage, must not carry ``delta.content``, and the text accumulated
    so far must match ``re_content``.

    The test also fails if the stream ends without any such terminal chunk;
    otherwise none of the usage/content checks would have run and the test
    would pass vacuously.
    """
    server.start()
    stream = server.make_stream_request("POST", "/chat/completions", data={
        "model": model,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "stream": True,
    })

    pieces = []
    saw_final_chunk = False
    for data in stream:
        choice = data["choices"][0]
        if choice["finish_reason"] in ["stop", "length"]:
            saw_final_chunk = True
            assert data["usage"]["prompt_tokens"] == n_prompt
            assert data["usage"]["completion_tokens"] == n_predicted
            assert "content" not in choice["delta"]
            assert match_regex(re_content, "".join(pieces))
            # FIXME: the exact finish_reason ("length" when `truncated`,
            # "stop" otherwise) is not asserted here; it was observed to be
            # incorrect in stream mode and the cause is not yet understood.
        else:
            assert choice["finish_reason"] is None
            pieces.append(choice["delta"]["content"])

    # Guard against a stream that never delivered the terminal chunk.
    assert saw_final_chunk, "stream ended without a chunk carrying a finish_reason"
def test_chat_completion_with_openai_library():
    """Exercise the chat completion endpoint through the ``openai`` client.

    Points an ``OpenAI`` client at the test server, requests a short
    completion with a fixed seed, and checks the first choice's finish
    reason and content.
    """
    server.start()

    endpoint = f"http://{server.server_host}:{server.server_port}"
    client = OpenAI(api_key="dummy", base_url=endpoint)

    conversation = [
        {"role": "system", "content": "Book"},
        {"role": "user", "content": "What is the best book"},
    ]
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        messages=conversation,
        max_tokens=8,
        seed=42,
        temperature=0.8,
    )
    print(res)

    first_choice = res.choices[0]
    assert first_choice.finish_reason == "stop"
    reply = first_choice.message.content
    assert reply is not None
    assert match_regex("(Suddenly)+", reply)
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
    ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
    ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
    ({"type": "json_object"}, 10, "(\\{|John)+"),
    ({"type": "sound"}, 0, None),
    # invalid response format (expected to fail)
    ({"type": "json_object", "schema": 123}, 0, None),
    ({"type": "json_object", "schema": {"type": 123}}, 0, None),
    ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
])
def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
    """Check how ``/chat/completions`` handles the ``response_format`` field.

    When ``re_content`` is a pattern, the request must succeed and the first
    choice's message content must match it. When ``re_content`` is ``None``,
    the request is expected to be rejected: a non-200 status and an "error"
    key in the response body.
    """
    server.start()

    payload = {
        "max_tokens": n_predicted,
        "messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write an example"},
        ],
        "response_format": response_format,
    }
    res = server.make_request("POST", "/chat/completions", data=payload)

    # Rejection path first, so the success path reads straight through.
    if re_content is None:
        assert res.status_code != 200
        assert "error" in res.body
        return

    assert res.status_code == 200
    first_choice = res.body["choices"][0]
    assert match_regex(re_content, first_choice["message"]["content"])
@pytest.mark.parametrize("messages", [
    None,
    "string",
    [123],
    [{}],
    [{"role": 123}],
    [{"role": "system", "content": 123}],
    # [{"content": "hello"}], # TODO: should not be a valid case
    [{"role": "system", "content": "test"}, {}],
])
def test_invalid_chat_completion_req(messages):
    """Malformed ``messages`` payloads must be answered with an error.

    Each parametrized value is sent as the ``messages`` field; the server
    must reply with status 400 or 500 and include an "error" key in the body.
    """
    server.start()

    payload = {"messages": messages}
    res = server.make_request("POST", "/chat/completions", data=payload)

    assert res.status_code in (400, 500)
    assert "error" in res.body