
Commit 7dc69cc

add load_added_tokens; support med-r1.
1 parent 0d526b2 commit 7dc69cc

7 files changed: +164 −21 lines changed

docs/models.md

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
 * Baichuan (`BaichuanForCausalLM`, `BaichuanM1ForCausalLM`)
     * [x] [Chat-7B](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat), [Chat-13B](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat)
     * [x] M1: [Instruct-14B](https://huggingface.co/baichuan-inc/Baichuan-M1-14B-Instruct)
+    * [x] Fine-tunings: [Med-R1](https://modelscope.cn/models/wangrongsheng/Med-R1/files) (Tip: `--set chat_template im`)

 * BlueLM (`BlueLMForCausalLM`)
     * [x] [Chat-7B](https://huggingface.co/vivo-ai/BlueLM-7B-Chat), [Chat-7B 32K](https://huggingface.co/vivo-ai/BlueLM-7B-Chat-32K)
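
The tip above relies on the new `--set` option introduced in `src/main.cpp` below. A hypothetical invocation (binary and model paths are placeholders, not part of this commit) could look like:

    ./bin/main -m /path/to/med-r1.bin --set chat_template im

`chat_template im` is forwarded to the model as an additional key-value arg, which the Baichuan M1 implementation uses to switch to the ChatML-style (`<|im_start|>`/`<|im_end|>`) history encoder.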

models/baichuan.cpp

Lines changed: 89 additions & 1 deletion
@@ -139,14 +139,21 @@ namespace m1
     {
     public:
         Tokenizer(const Config &config)
-            : llama::v2::Tokenizer(config, &_chat_encoder)
+            : llama::v2::Tokenizer(config, &_chat_encoder),
+              im_end_token_id(-1)
         {
             sys_prompt = "You are a helpful assistant.";
         }

         size_t load(tokenizer::DataReader *buffer, int n_vocab) override
         {
             size_t r = llama::v2::Tokenizer::load(buffer, n_vocab);
+
+            int id = tp->PieceToId("<reserved_147>");
+            if (id >= 0) tp->OverrideTokenDecoding(id, "<think>");
+            id = tp->PieceToId("<reserved_148>");
+            if (id >= 0) tp->OverrideTokenDecoding(id, "</think>");
+
             b_sys_token_id = 71;
             b_usys_token_id = 72;
             c_q_token_id = 73;
@@ -163,13 +170,34 @@ namespace m1
             llama::v2::Tokenizer::encode(text, ids);
         }

+        bool load_config(const json::JSON &config) override
+        {
+            load_added_tokens(config, {
+                {"<B_SYS>", &b_sys_token_id},
+                {"<B_USYS>", &b_usys_token_id},
+                {"<C_Q>", &c_q_token_id},
+                {"<C_A>", &c_a_token_id},
+                {"<B_FUNC>", &b_func_token_id},
+                {"<B_CODE>", &b_code_token_id},
+                {"<|im_start|>", &im_start_token_id},
+                {"<|im_end|>", &im_end_token_id},
+            });
+
+            if (im_end_token_id >= 0)
+                terminate_ids.insert(im_end_token_id);
+
+            return true;
+        }
+
     public:
         int b_sys_token_id;
         int b_usys_token_id;
         int c_q_token_id;
         int c_a_token_id;
         int b_func_token_id;
         int b_code_token_id;
+        int im_start_token_id;
+        int im_end_token_id;
     };

     void ChatHistoryEncoder::append_sys_prompt(std::vector<int> &ids) const
@@ -202,6 +230,54 @@ namespace m1
         ids.push_back(tok->c_a_token_id);
     }

+    static class ImChatHistoryEncoder : public BaseHistoryEncoder
+    {
+    public:
+        void append_sys_prompt(std::vector<int> &ids) const override
+        {
+            Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+
+            if (tok->get_system_prompt().size() > 0)
+            {
+                ids.push_back(tok->im_start_token_id);
+                tok->encode("system\n", ids);
+                tok->encode(tok->get_system_prompt(), ids);
+                ids.push_back(tok->im_end_token_id);
+                tok->encode("\n", ids);
+            }
+        }
+        void append_ai(int round_idx, const std::string &ai, std::vector<int> &ids) const override
+        {
+            Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+            append_ai_opening(round_idx, ids);
+            tok->encode(ai, ids);
+            ids.push_back(tok->im_end_token_id);
+            tok->encode("\n", ids);
+        }
+        void append_user(int round_idx, const std::string &user, std::vector<int> &ids) const override
+        {
+            Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+            append_user_opening(round_idx, ids);
+            tok->encode(user, ids);
+            ids.push_back(tok->im_end_token_id);
+            tok->encode("\n", ids);
+        }
+
+        void append_ai_opening(int round_idx, std::vector<int> &ids) const override
+        {
+            Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+            ids.push_back(tok->im_start_token_id);
+            tok->encode("assistant\n", ids);
+        }
+
+        void append_user_opening(int round_idx, std::vector<int> &ids) const override
+        {
+            Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
+            ids.push_back(tok->im_start_token_id);
+            tok->encode("user\n", ids);
+        }
+    } _im_chat_encoder;
+
     template <int sliding_window_len> class BaiChuanSWASelfAttention : public RoPESelfAttention<SlidingWindowAttentionImpl<sliding_window_len>>
     {
     public:
@@ -316,6 +392,18 @@ namespace m1
                 << "corrupted model weights: " << w_ctx_.get_used_mem() / ggml_tensor_overhead() << " != " << w_ctx_.get_mem_size() / ggml_tensor_overhead();
         }

+        void set_additional_args(const std::map<std::string, std::string> &args) override
+        {
+            auto it = args.find("chat_template");
+            if (it != args.end())
+            {
+                if (it->second == "im")
+                {
+                    tokenizer->set_chat_encoder(&_im_chat_encoder);
+                }
+            }
+        }
+
         void load(ModelLoader &loader) override
         {
             auto transformer = get_typed_transformer<ModelClass>();
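
Pieced together from `ImChatHistoryEncoder` above, a single round is laid out in the familiar ChatML shape (special tokens are emitted as token ids, shown here as text; the messages are placeholders):

    <|im_start|>system
    You are a helpful assistant.<|im_end|>
    <|im_start|>user
    {user message}<|im_end|>
    <|im_start|>assistant

On the decoding side, the `load()` override maps `<reserved_147>`/`<reserved_148>` to `<think>`/`</think>`, so the fine-tune's reasoning spans show up with the usual thought tags.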

models/kimi.cpp

Lines changed: 14 additions & 19 deletions
@@ -71,25 +71,20 @@ namespace vl

         bool load_config(const json::JSON &config) override
         {
-            auto cfg = config["tokenizer_config.json"];
-            if (!cfg.IsObject()) return false;
-            auto added_tokens_decoder = cfg["added_tokens_decoder"];
-            if (!added_tokens_decoder.IsObject()) return false;
-
-            for (auto &item : added_tokens_decoder.ObjectRange())
-            {
-                #define check_token(tok) if ("<|" #tok "|>" == item.second["content"].ToString()) tok ## _token_id = std::stol(item.first)
-                check_token(im_end);
-                else check_token(im_user);
-                else check_token(im_assistant);
-                else check_token(im_system);
-                else check_token(im_middle);
-                else check_token(media_start);
-                else check_token(media_content);
-                else check_token(media_end);
-                else check_token(media_pad);
-                else;
-            }
+            #define check_token(tok) {std::string("<|" #tok "|>"), &(tok ## _token_id)}
+
+            load_added_tokens(config, {
+                check_token(im_end),
+                check_token(im_user),
+                check_token(im_assistant),
+                check_token(im_system),
+                check_token(im_middle),
+                check_token(media_start),
+                check_token(media_content),
+                check_token(media_end),
+                check_token(media_pad),
+            });
+            #undef check_token

             if (im_end_token_id >= 0)
                 terminate_ids.insert(im_end_token_id);
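
For reference, each `check_token(...)` entry in the rewritten `load_config` expands to one pair of the initializer list passed to `load_added_tokens`; for example `check_token(im_end)` becomes:

    {std::string("<|im_end|>"), &(im_end_token_id)}

so the old per-item if/else chain over `added_tokens_decoder` is replaced by a single table-driven lookup in the base class.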

scripts/richchat.py

Lines changed: 15 additions & 1 deletion
@@ -71,8 +71,11 @@ def callback_async_done(self) -> None:

 llm: RichChatLLM = None
 MAX_THOUGHT_TIME = 60 * 3
+multiple_lines_input = False

 def params_preprocess(params: list[str]) -> list[str]:
+    global multiple_lines_input
+    multiple_lines_input = '--multi' in params
     for i, s in enumerate(params):
         if (s == '--max-thought-time') and (i + 1 < len(params)):
             global MAX_THOUGHT_TIME
@@ -90,20 +93,31 @@ def handler(signal_received, frame):
     llm.show_meta('Statistics')
     sys.exit(0)

+def user_input(prompt: str) -> str:
+    global multiple_lines_input
+    if multiple_lines_input:
+        print(prompt, end='', flush=True)
+        return sys.stdin.read()
+    else:
+        return input(prompt)
+
 def demo_simple(params, lib_path: str, cls = RichChatLLM):
     global llm
+    global multiple_lines_input
     signal.signal(signal.SIGINT, handler)
     llm = cls(LibChatLLM(lib_path), params)

     llm.show_meta('Model')
+    if multiple_lines_input:
+        print('Press Ctrl+D / Ctrl+Z (Windows) to finish input')

     render_ai = lambda: llm.render_ai()
     render_thoughts = lambda: llm.render_thoughts()

     console = Console()

     while True:
-        s = input('You > ')
+        s = user_input('You > ')
         if s == '': continue

         if s.startswith('/start'):
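
A short, illustrative session with the new flag (the model argument is a placeholder; only `--multi` and the prompt behaviour come from this diff): start the script with `--multi`, type or paste several lines, then press Ctrl+D (Ctrl+Z then Enter on Windows) to submit them as one message:

    $ python scripts/richchat.py -m /path/to/model.bin --multi
    Press Ctrl+D / Ctrl+Z (Windows) to finish input
    You > first line of a long question
    second line of the same question
    ^D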

src/chat.cpp

Lines changed: 32 additions & 0 deletions
@@ -350,6 +350,13 @@ namespace chatllm
         qa_encoder->set_tokenizer(this);
     }

+    void BaseTokenizer::set_chat_encoder(BaseHistoryEncoder *encoder)
+    {
+        chat_encoder = encoder;
+        if (encoder)
+            encoder->set_tokenizer(this);
+    }
+
     bool BaseTokenizer::is_terminate_token_id(int id) const
     {
         if (id == eos_token_id) return true;
@@ -543,6 +550,31 @@ namespace chatllm
         qa_encoder->skip_sys_prompt = skip;
     }

+    int BaseTokenizer::load_added_tokens(const json::JSON &config, std::initializer_list<std::pair<std::string, int *>> added_tokens)
+    {
+        int r = -1;
+        auto cfg = config["tokenizer_config.json"];
+        if (!cfg.IsObject()) return r;
+        auto added_tokens_decoder = cfg["added_tokens_decoder"];
+        if (!added_tokens_decoder.IsObject()) return r;
+
+        r = 0;
+
+        for (auto &item : added_tokens_decoder.ObjectRange())
+        {
+            for (auto tok = added_tokens.begin(), e = added_tokens.end(); tok != e; ++tok)
+            {
+                if (tok->first == item.second["content"].ToString())
+                {
+                    *tok->second = std::stol(item.first);
+                    break;
+                }
+            }
+        }
+
+        return r;
+    }
+
     void BaseHistoryEncoder::append_sys_prompt(std::vector<int> &ids) const
     {
     }
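
`load_added_tokens` walks the `added_tokens_decoder` object from `tokenizer_config.json`, where each key is a token id (as a string) and each value carries the token text under `content`. A trimmed, illustrative fragment of the shape it expects (ids and extra fields are examples only):

    "added_tokens_decoder": {
      "73": { "content": "<C_Q>", "special": true },
      "74": { "content": "<C_A>", "special": true }
    }

Any entry whose `content` matches a name in the caller's list has its id parsed with `std::stol` and written through the supplied pointer.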

src/chat.h

Lines changed: 4 additions & 0 deletions
@@ -192,6 +192,8 @@ namespace chatllm

         virtual void set_skip_sys_prompt(bool skip);

+        void set_chat_encoder(BaseHistoryEncoder *encoder);
+
         int bos_token_id;
         int eos_token_id;
         int pad_token_id;
@@ -203,6 +205,8 @@ namespace chatllm
         virtual std::string preprocess(const std::string &text) const;
         virtual std::string postprocess(const std::string &text) const;

+        int load_added_tokens(const json::JSON &config, std::initializer_list<std::pair<std::string, int *>> added_tokens);
+
     public:
         tokenizer::Processor *tp;
     protected:

src/main.cpp

Lines changed: 9 additions & 0 deletions
@@ -233,6 +233,7 @@ void usage(const std::string &prog)
         << "  --log_level                 log level. (default: 4 - ERROR)\n"
         << "  --serve_rpc [H:]P[@id]      as a RPC server on host:port (optional: host default to 127.0.0.1, id defaults to 0) [#]\n"
         << "  --ggml_dir DIR              specify directory of GGML\n"
+        << "  --set KEY VALUE             set a pair of additional args.\n"
         << "Additional key-value args:\n"
         << "  --kv                        start of additional args. all following options are interpreted as k-v pairs\n"
         << "  key value                   a key-value pair of args\n"
@@ -377,6 +378,14 @@ static size_t parse_args(Args &args, const std::vector<std::string> &argv)
                 args.detect_thoughts = true;
             }
         }
+        else if (strcmp(arg, "--set") == 0)
+        {
+            if (c + 2 < argc)
+            {
+                args.additional[argv[c + 1]] = argv[c + 2];
+                c += 2;
+            }
+        }
         handle_param("--model",           "-m", model_path,          std::string)
         handle_param("--prompt",          "-p", prompt,              std::string)
         handle_para0("--prompt_file",           prompt,              load_txt)
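
Taken together with the earlier hunks, `--set KEY VALUE` simply records the pair in `args.additional`, and each pair is later handed to the model through `set_additional_args`; for example:

    --set chat_template im    →    args.additional["chat_template"] = "im"

which is the key the Baichuan M1 model checks to enable the `<|im_start|>`-style encoder.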
