@@ -21,47 +21,114 @@ class chat_template {
21
21
public:
22
22
23
23
private:
24
- bool _supports_tools = true ;
24
+ bool supports_tools_ = true ;
25
25
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
26
26
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
27
- bool _requires_object_arguments = false ;
28
- bool _supports_system_role = true ;
29
- std::string _source;
30
- std::string _bos_token;
31
- std::string _eos_token;
32
- std::shared_ptr<minja::TemplateNode> _template_root;
27
+ bool requires_object_arguments_ = false ;
28
+ bool supports_system_role_ = true ;
29
+ bool supports_parallel_tool_calls_ = false ;
30
+ std::string source_;
31
+ std::string bos_token_;
32
+ std::string eos_token_;
33
+ std::shared_ptr<minja::TemplateNode> template_root_;
34
+
35
+ std::string try_render (
36
+ const nlohmann::ordered_json & messages,
37
+ const nlohmann::ordered_json & tools,
38
+ bool add_generation_prompt,
39
+ const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
40
+ {
41
+ try {
42
+ auto prompt = apply (messages, tools, add_generation_prompt, extra_context);
43
+ // fprintf(stderr, "Prompt: %s\n", prompt.c_str());
44
+ return prompt;
45
+ } catch (const std::exception & e) {
46
+ // fprintf(stderr, "Error: %s\n", e.what());
47
+ return " " ;
48
+ }
49
+ }
33
50
34
51
public:
35
52
chat_template (const std::string & source, const std::string & bos_token, const std::string & eos_token)
36
- : _source (source), _bos_token (bos_token), _eos_token (eos_token)
53
+ : source_ (source), bos_token_ (bos_token), eos_token_ (eos_token)
37
54
{
38
- _supports_tools = source.find (" tools" ) != std::string::npos;
39
- _requires_object_arguments =
40
- source.find (" tool_call.arguments | items" ) != std::string::npos
41
- || source.find (" tool_call.arguments | tojson" ) != std::string::npos;
42
- _supports_system_role = source.find (" System role not supported" ) == std::string::npos;
43
-
44
- _template_root = minja::Parser::parse (_source, {
55
+ template_root_ = minja::Parser::parse (source_, {
45
56
/* .trim_blocks = */ true ,
46
57
/* .lstrip_blocks = */ true ,
47
58
/* .keep_trailing_newline = */ false ,
48
59
});
60
+ supports_tools_ = source.find (" tools" ) != std::string::npos;
61
+
62
+ auto renders_string_arguments =
63
+ try_render ({
64
+ {
65
+ {" role" , " user" },
66
+ {" content" , " Hey" }
67
+ },
68
+ {
69
+ {" role" , " assistant" },
70
+ {" tool_calls" , json::array ({
71
+ {
72
+ {" id" , " call_1___" },
73
+ {" type" , " function" },
74
+ {" function" , {
75
+ {" arguments" , " {\" code\" : \" print('Hello, World!')\" }" },
76
+ {" name" , " ipython" },
77
+ }},
78
+ },
79
+ })},
80
+ }
81
+ }, {}, false ).find (" {\" code\" : \" print" ) != std::string::npos;
82
+ if (!renders_string_arguments) {
83
+ auto renders_object_arguments =
84
+ try_render ({
85
+ {
86
+ {" role" , " user" },
87
+ {" content" , " Hey" }
88
+ },
89
+ {
90
+ {" role" , " assistant" },
91
+ {" tool_calls" , json::array ({
92
+ {
93
+ {" id" , " call_1___" },
94
+ {" type" , " function" },
95
+ {" function" , {
96
+ {" arguments" , {
97
+ {" code" , " print('Hello, World!')" },
98
+ }},
99
+ {" name" , " ipython" },
100
+ }},
101
+ },
102
+ })},
103
+ }
104
+ }, {}, false ).find (" {\" code\" : \" print" ) != std::string::npos;
105
+ requires_object_arguments_ = renders_object_arguments;
106
+ }
107
+ supports_parallel_tool_calls_ = source.find (" tool_call_id" ) != std::string::npos;
108
+
109
+ supports_system_role_ = try_render ({
110
+ {{" role" , " system" }, {" content" , " <System Needle>" }},
111
+ {{" role" , " user" }, {" content" , " Hey" }}
112
+ }, {}, false ).find (" <System Needle>" ) != std::string::npos;
49
113
}
50
114
51
- const std::string & source () const { return _source; }
52
- bool supports_tools () const { return _supports_tools; }
115
+ const std::string & source () const { return source_; }
116
+ bool supports_tools () const { return supports_tools_; }
117
+ bool supports_parallel_tool_calls () const { return supports_parallel_tool_calls_; }
53
118
54
119
std::string apply (
55
120
const nlohmann::ordered_json & messages,
56
121
const nlohmann::ordered_json & tools,
57
122
bool add_generation_prompt,
58
123
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
59
124
{
60
- auto actual_messages = messages ;
125
+ json actual_messages;
61
126
62
127
// First, "fix" messages so they have a chance to be rendered correctly by the template
63
128
64
- if (_requires_object_arguments || !_supports_system_role) {
129
+ if (requires_object_arguments_ || !supports_system_role_ || !supports_tools_) {
130
+ actual_messages = json::array ();
131
+
65
132
std::string pending_system;
66
133
auto flush_sys = [&]() {
67
134
if (!pending_system.empty ()) {
@@ -72,13 +139,66 @@ class chat_template {
72
139
pending_system.clear ();
73
140
}
74
141
};
75
- for (auto & message : actual_messages) {
142
+ for (const auto & message_ : messages) {
143
+ auto message = message_;
76
144
if (!message.contains (" role" ) || !message.contains (" content" )) {
77
145
throw std::runtime_error (" message must have 'role' and 'content' fields: " + message.dump ());
78
146
}
79
147
std::string role = message.at (" role" );
80
148
81
- if (!message[" content" ].is_null () && !_supports_system_role) {
149
+ if (message.contains (" tool_calls" )) {
150
+ if (requires_object_arguments_ || !supports_tools_) {
151
+ for (auto & tool_call : message.at (" tool_calls" )) {
152
+ if (tool_call[" type" ] == " function" ) {
153
+ auto & function = tool_call.at (" function" );
154
+ std::string arguments = function.at (" arguments" );
155
+ function[" arguments" ] = json::parse (arguments);
156
+ }
157
+ }
158
+ }
159
+ if (!supports_tools_) {
160
+ auto content = message.at (" content" );
161
+ auto tool_calls = json::array ();
162
+ for (const auto & tool_call : message.at (" tool_calls" )) {
163
+ if (tool_call.at (" type" ) != " function" ) {
164
+ continue ;
165
+ }
166
+ const auto & function = tool_call.at (" function" );
167
+ auto tc = json {
168
+ {" name" , function.at (" name" )},
169
+ {" arguments" , function.at (" arguments" )},
170
+ };
171
+ if (tool_call.contains (" id" )) {
172
+ tc[" id" ] = tool_call[" id" ];
173
+ }
174
+ tool_calls.push_back (tc);
175
+ }
176
+ auto obj = json {
177
+ {" tool_calls" , tool_calls},
178
+ };
179
+ if (!content.is_null () && content != " " ) {
180
+ obj[" content" ] = content;
181
+ }
182
+ message[" content" ] = obj.dump (2 );
183
+ message.erase (" tool_calls" );
184
+ }
185
+ }
186
+ if (!supports_tools_ && role == " tool" ) {
187
+ message[" role" ] = " user" ;
188
+ auto obj = json {
189
+ {" tool_response" , {
190
+ {" tool" , message.at (" name" )},
191
+ {" content" , message.at (" content" )},
192
+ }},
193
+ };
194
+ if (message.contains (" tool_call_id" )) {
195
+ obj[" tool_response" ][" tool_call_id" ] = message.at (" tool_call_id" );
196
+ }
197
+ message[" content" ] = obj.dump (2 );
198
+ message.erase (" name" );
199
+ }
200
+
201
+ if (!message[" content" ].is_null () && !supports_system_role_) {
82
202
std::string content = message.at (" content" );
83
203
if (role == " system" ) {
84
204
if (!pending_system.empty ()) pending_system += " \n " ;
@@ -95,24 +215,18 @@ class chat_template {
95
215
}
96
216
}
97
217
}
98
- if (_requires_object_arguments && message.contains (" tool_calls" )) {
99
- for (auto & tool_call : message.at (" tool_calls" )) {
100
- if (tool_call[" type" ] == " function" ) {
101
- auto & function = tool_call.at (" function" );
102
- std::string arguments = function.at (" arguments" );
103
- function[" arguments" ] = json::parse (arguments);
104
- }
105
- }
106
- }
218
+ actual_messages.push_back (message);
107
219
}
108
220
flush_sys ();
221
+ } else {
222
+ actual_messages = messages;
109
223
}
110
224
111
225
auto context = minja::Context::make (json ({
112
226
{" messages" , actual_messages},
113
227
{" add_generation_prompt" , add_generation_prompt},
114
- {" bos_token" , _bos_token },
115
- {" eos_token" , _eos_token },
228
+ {" bos_token" , bos_token_ },
229
+ {" eos_token" , eos_token_ },
116
230
}));
117
231
118
232
if (!tools.is_null ()) {
@@ -126,8 +240,8 @@ class chat_template {
126
240
}
127
241
}
128
242
129
- return _template_root ->render (context);
243
+ return template_root_ ->render (context);
130
244
}
131
245
};
132
246
133
- } // namespace minja
247
+ } // namespace minja
0 commit comments