#
#
# python -m examples.openai.test_chat_handlers | tee examples/openai/test_chat_handlers.md
import json
import sys
from examples.openai.api import FunctionCall, Message, Tool, ToolCall, ToolFunction
from examples.openai.prompting import ChatHandlerArgs, ChatTemplate, ToolsPromptStyle, get_chat_handler
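# Shared fixtures: a simple addition question, a canned assistant answer, and a
# tool call with known arguments. The checks below only look for these values
# verbatim in rendered prompts, so the answer text does not need to equal TEST_SUM.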
TEST_ARG_A = 2535
TEST_ARG_B = 32222000403
TEST_SUM = 32222002938
QUESTION = "Add two numbers for the purpose of this test."
ANSWER = "The sum of 2535 and 32222000403 is 42."
PROMPT_MESSAGE = Message(
    role="user",
    content=QUESTION,
)

ASSIST_MESSAGE = Message(
    role="assistant",
    content=ANSWER,
)
TOOL_NAME = "superSecretTool"
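# A single tool call with a fixed id; TOOL_MESSAGE below replies to the same id.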
TOOL_CALL = ToolCall(
    id="call_531873",
    type="function",
    function=FunctionCall(
        name=TOOL_NAME,
        arguments={
            "a": TEST_ARG_A,
            "b": TEST_ARG_B
        }
    )
)

TOOL_CALL_MESSAGE = Message(
    role="assistant",
    content=None,
    tool_calls=[TOOL_CALL],
)
TEST_THOUGHT = "I've thought a lot about this."
THOUGHTFUL_TOOL_CALL_MESSAGE = Message(
    role="assistant",
    content=TEST_THOUGHT,
    tool_calls=[TOOL_CALL],
)

# UNDERSCORE_ESCAPED_TOOL_CALL_MESSAGE = Message(**{
#     **TOOL_CALL_MESSAGE.model_dump(),
#     "tool_calls": [
#         json.loads(tc.model_dump_json().replace("_", "\\_"))
#         for tc in TOOL_CALL_MESSAGE.tool_calls
#     ],
# })
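# The tool's reply, referencing the same call id as TOOL_CALL above.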
TOOL_MESSAGE = Message(
    role="tool",
    name=TOOL_NAME,
    tool_call_id="call_531873",
    content=f'{TEST_SUM}',
    tool_calls=None
)
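# Two conversations used to exercise prompt rendering: one with a plain tool call
# and one where the assistant also emits some "thought" content alongside the call.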
TEST_MESSAGES = [
    PROMPT_MESSAGE,
    TOOL_CALL_MESSAGE,
    TOOL_MESSAGE,
    ASSIST_MESSAGE,
]

TEST_MESSAGES_THOUGHT = [
    PROMPT_MESSAGE,
    THOUGHTFUL_TOOL_CALL_MESSAGE,
    TOOL_MESSAGE,
    ASSIST_MESSAGE,
]
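# Tools advertised to the chat handlers: the adder exercised above plus a second,
# unused tool so the rendered schemas contain more than one entry.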
TEST_TOOLS = [
    Tool(
        type="function",
        function=ToolFunction(
            name=TOOL_NAME,
            description="Adds two numbers",
            parameters={
                "properties": {
                    "a": {"type": "integer"},
                    "b": {"type": "integer"},
                },
                "required": ["a", "b"]
            }
        )
    ),
    Tool(
        type="function",
        function=ToolFunction(
            name="say",
            description="Says something out loud (TTS)",
            parameters={
                "properties": {
                    "text": {
                        "description": "The text to say out loud",
                        "type": "string"
                    },
                },
                "required": ["text"]
            }
        )
    )
]
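# Schema for the final (non-tool-call) response, passed to every ChatHandlerArgs below.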
TEST_OUTPUT_SCHEMA = {"type": "integer"}
# Generate the JSON for TEST_TEMPLATES below by uncommenting this block:
#
# TEST_TEMPLATES = {
#     'mistral_instruct_v0_1': ChatTemplate.from_huggingface("mistralai/Mixtral-8x7B-Instruct-v0.1"),
#     'functionary_v2_2': ChatTemplate.from_huggingface("meetkai/functionary-small-v2.2"),
#     'hermes_2_pro_mistral': ChatTemplate.from_huggingface("NousResearch/Hermes-2-Pro-Mistral-7B"),
#     'llama2': ChatTemplate.from_huggingface("meta-llama/Llama-2-7b-chat-hf"),
# }
# print(json.dumps({k: v.model_dump() for k, v in TEST_TEMPLATES.items()}, indent=2))
# exit(0)
TEST_TEMPLATES = {
    "mistral_instruct_v0_1": {
        "template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
        "eos_token": "</s>",
        "bos_token": "<s>"
    },
    "functionary_v2_2": {
        "template": "{#v2.2#}\n{% for message in messages %}\n{% if message['role'] == 'user' or message['role'] == 'system' %}\n{{ '<|from|>' + message['role'] + '\n<|recipient|>all\n<|content|>' + message['content'] + '\n' }}{% elif message['role'] == 'tool' %}\n{{ '<|from|>' + message['name'] + '\n<|recipient|>all\n<|content|>' + message['content'] + '\n' }}{% else %}\n{% set contain_content='no'%}\n{% if message['content'] is not none %}\n{{ '<|from|>assistant\n<|recipient|>all\n<|content|>' + message['content'] }}{% set contain_content='yes'%}\n{% endif %}\n{% if 'tool_calls' in message and message['tool_calls'] is not none %}\n{% for tool_call in message['tool_calls'] %}\n{% set prompt='<|from|>assistant\n<|recipient|>' + tool_call['function']['name'] + '\n<|content|>' + tool_call['function']['arguments'] %}\n{% if loop.index == 1 and contain_content == \"no\" %}\n{{ prompt }}{% else %}\n{{ '\n' + prompt}}{% endif %}\n{% endfor %}\n{% endif %}\n{{ '<|stop|>\n' }}{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}{{ '<|from|>assistant\n<|recipient|>' }}{% endif %}",
        "eos_token": "</s>",
        "bos_token": "<s>"
    },
    "hermes_2_pro_mistral": {
        "template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
        "eos_token": "<|im_end|>",
        "bos_token": "<s>"
    },
    "llama2": {
        "template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
        "eos_token": "</s>",
        "bos_token": "<s>"
    },
}
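# Only functionary's template is expected to be detected as supporting parallel tool calls.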
MODELS_WITH_PARALLEL_CALLS = set(["functionary_v2_2"])
TEST_TEMPLATES = {k: ChatTemplate(**v) for k, v in TEST_TEMPLATES.items()}
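# Render prompts for every applicable template / tool-prompt-style combination,
# emit them as a markdown report, and check that message contents survive rendering.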
if __name__ == "__main__":
    failures = []

    print(f'\nMessages:\n\n```js\n{json.dumps([m.model_dump() for m in TEST_MESSAGES_THOUGHT], indent=2)}\n```\n')

    def check(b: bool, msg: str):
        if not b:
            sys.stderr.write(f'FAILURE: {msg}\n\n')
            failures.append(msg)
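    # Template-level sanity checks: tool-style inference and parallel-call detection.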
    functionary_v2_2 = TEST_TEMPLATES["functionary_v2_2"]
    check(functionary_v2_2.inferred_tool_style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2, "functionary_v2_2 should be inferred as TYPESCRIPT_FUNCTIONARY_V2")

    for model_name, chat_template in TEST_TEMPLATES.items():
        check(chat_template.potentially_supports_parallel_calls == (model_name in MODELS_WITH_PARALLEL_CALLS),
              f"{model_name} should {'not ' if model_name not in MODELS_WITH_PARALLEL_CALLS else ''}be detected as potentially supporting parallel calls")
        argss = {
            "with tools": ChatHandlerArgs(
                chat_template=chat_template, #ChatTemplate.from_gguf(GGUFKeyValues(model)),
                response_schema=TEST_OUTPUT_SCHEMA,
                tools=TEST_TOOLS,
            ),
            "without tools": ChatHandlerArgs(
                chat_template=chat_template, #ChatTemplate.from_gguf(GGUFKeyValues(model)),
                response_schema=TEST_OUTPUT_SCHEMA,
                tools=[],
            ),
        }

        print(f"\n# {model_name}\n")

        if chat_template.potentially_supports_parallel_calls:
            print("\n**Might Support Parallel Tool Calls**\n")

        print(f'\nTemplate:\n\n```js\n{chat_template.template}\n```\n')
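        # Only exercise style/template pairings that make sense (e.g. the functionary
        # style only for functionary templates, the Mixtral style only for Mistral).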
        for style in ToolsPromptStyle:
            if (style == ToolsPromptStyle.TYPESCRIPT_FUNCTIONARY_V2) != (model_name.startswith("functionary")):
                continue

            if style == ToolsPromptStyle.TOOLS_MIXTRAL and model_name != "mistral_instruct_v0_1":
                continue

            if model_name == "mistral_instruct_v0_1" and style not in (ToolsPromptStyle.TOOLS_THOUGHTFUL_STEPS, ToolsPromptStyle.TOOLS_MIXTRAL):
                continue

            print(f'\n## {model_name} / {style.name}\n')

            for tool_situation, args in argss.items():
                ch = get_chat_handler(args, parallel_calls=True, tool_style=style)

                print(f'\n### {model_name} / {style.name} / {tool_situation}\n')

                print(f'\nPrompt:\n\n```js\n{ch.render_prompt(TEST_MESSAGES_THOUGHT)}\n```\n')

                print(f'\nOutput format prompt:\n\n```json\n{ch.output_format_prompt.content}\n```\n')

                print(f'\nGrammar:\n\n```js\n{ch.grammar}\n```\n')

                # if model_name == 'hermes_2_pro_mistral':
                #     print("Skipping hermes_2_pro_mistral")
                #     continue
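                # Verify that each message type's content survives verbatim into the rendered prompt.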
                def check_finds(msgs, strings_to_find):
                    prompt = ch.render_prompt(msgs)
                    for s in strings_to_find:
                        check(str(s) in prompt, f"Missing {s} in prompt for {model_name}:\n{prompt}")

                check_finds([PROMPT_MESSAGE], (QUESTION,))
                check_finds([ASSIST_MESSAGE], (ANSWER,))
                check_finds([TOOL_CALL_MESSAGE], (TEST_ARG_A, TEST_ARG_B, TOOL_NAME))
                check_finds([THOUGHTFUL_TOOL_CALL_MESSAGE], (TEST_THOUGHT, TEST_ARG_A, TEST_ARG_B, TOOL_NAME,))
                check_finds([TOOL_MESSAGE], (TEST_SUM,))
                if chat_template.potentially_supports_parallel_calls:
                    check_finds([TOOL_MESSAGE], (TOOL_NAME,))
    if failures:
        for f in failures:
            print(f'{f}\n\n')

    assert not failures